Skip to content

Commit e5ce951

Browse files
authored
fix: Always use cloudpickle for the python objects in cloud plans (#23474)
1 parent 80c34c2 commit e5ce951

File tree

8 files changed

+140
-190
lines changed

8 files changed

+140
-190
lines changed

crates/polars-plan/src/client/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mod check;
22

33
use polars_core::error::PolarsResult;
44

5-
use crate::dsl::DslPlan;
5+
use crate::dsl::{DslPlan, PlanSerializationContext};
66

77
/// Prepare the given [`DslPlan`] for execution on Polars Cloud.
88
pub fn prepare_cloud_plan(dsl: DslPlan) -> PolarsResult<Vec<u8>> {
@@ -11,7 +11,12 @@ pub fn prepare_cloud_plan(dsl: DslPlan) -> PolarsResult<Vec<u8>> {
1111

1212
// Serialize the plan.
1313
let mut writer = Vec::new();
14-
dsl.serialize_versioned(&mut writer)?;
14+
dsl.serialize_versioned(
15+
&mut writer,
16+
PlanSerializationContext {
17+
use_cloudpickle: true,
18+
},
19+
)?;
1520

1621
Ok(writer)
1722
}

crates/polars-plan/src/dsl/plan.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ impl Default for DslPlan {
234234
}
235235
}
236236

237+
#[derive(Default, Clone, Copy)]
238+
pub struct PlanSerializationContext {
239+
pub use_cloudpickle: bool,
240+
}
241+
237242
impl DslPlan {
238243
pub fn describe(&self) -> PolarsResult<String> {
239244
Ok(self.clone().to_alp()?.describe())
@@ -269,9 +274,19 @@ impl DslPlan {
269274
}
270275

271276
#[cfg(feature = "serde")]
272-
pub fn serialize_versioned<W: Write>(&self, mut writer: W) -> PolarsResult<()> {
277+
pub fn serialize_versioned<W: Write>(
278+
&self,
279+
mut writer: W,
280+
ctx: PlanSerializationContext,
281+
) -> PolarsResult<()> {
273282
let le_major = DSL_VERSION.0.to_le_bytes();
274283
let le_minor = DSL_VERSION.1.to_le_bytes();
284+
285+
// @GB:
286+
// This is absolute horrendous but serde does not allow for state to passed along with the
287+
// serialization so there is no proper way to do this except replace serde.
288+
polars_utils::pl_serialize::USE_CLOUDPICKLE.set(ctx.use_cloudpickle);
289+
275290
writer.write_all(DSL_MAGIC_BYTES)?;
276291
writer.write_all(&le_major)?;
277292
writer.write_all(&le_minor)?;

crates/polars-plan/src/dsl/python_dsl/python_udf.rs

Lines changed: 26 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ use polars_core::prelude::UnknownKind;
99
use polars_core::schema::Schema;
1010
use polars_utils::pl_str::PlSmallStr;
1111
use pyo3::prelude::*;
12-
use pyo3::pybacked::PyBackedBytes;
13-
use pyo3::types::PyBytes;
1412

1513
use crate::prelude::*;
1614

@@ -54,49 +52,27 @@ impl PythonUdfExpression {
5452

5553
#[cfg(feature = "serde")]
5654
pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
57-
// Handle byte mark
58-
5955
use polars_utils::pl_serialize;
60-
debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
61-
let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
6256

63-
// Handle pickle metadata
64-
let use_cloudpickle = buf[0];
65-
if use_cloudpickle != 0 {
66-
let ser_py_version = &buf[1..3];
67-
let cur_py_version = *PYTHON3_VERSION;
68-
polars_ensure!(
69-
ser_py_version == cur_py_version,
70-
InvalidOperation:
71-
"current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
72-
(3, cur_py_version[0], cur_py_version[1]),
73-
(3, ser_py_version[0], ser_py_version[1] )
74-
);
57+
if !buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK) {
58+
polars_bail!(InvalidOperation: "serialization expected python magic byte mark");
7559
}
76-
let buf = &buf[3..];
60+
let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
7761

7862
// Load UDF metadata
7963
let mut reader = Cursor::new(buf);
8064
let (output_type, is_elementwise, returns_scalar): (Option<DataTypeExpr>, bool, bool) =
8165
pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
8266

83-
let remainder = &buf[reader.position() as usize..];
84-
85-
// Load UDF
86-
Python::with_gil(|py| {
87-
let pickle = PyModule::import(py, "pickle")
88-
.expect("unable to import 'pickle'")
89-
.getattr("loads")
90-
.unwrap();
91-
let arg = (PyBytes::new(py, remainder),);
92-
let python_function = pickle.call1(arg)?;
93-
Ok(Arc::new(Self::new(
94-
python_function.into(),
95-
output_type,
96-
is_elementwise,
97-
returns_scalar,
98-
)) as Arc<dyn ColumnsUdf>)
99-
})
67+
let buf = &buf[reader.position() as usize..];
68+
let python_function = pl_serialize::python_object_deserialize(buf)?;
69+
70+
Ok(Arc::new(Self::new(
71+
python_function,
72+
output_type,
73+
is_elementwise,
74+
returns_scalar,
75+
)))
10076
}
10177
}
10278

@@ -145,48 +121,23 @@ impl ColumnsUdf for PythonUdfExpression {
145121

146122
#[cfg(feature = "serde")]
147123
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
148-
// Write byte marks
149-
150124
use polars_utils::pl_serialize;
125+
126+
// Write byte marks
151127
buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
152128

153-
Python::with_gil(|py| {
154-
// Try pickle to serialize the UDF, otherwise fall back to cloudpickle.
155-
let pickle = PyModule::import(py, "pickle")
156-
.expect("unable to import 'pickle'")
157-
.getattr("dumps")
158-
.unwrap();
159-
let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
160-
let (dumped, use_cloudpickle) = match pickle_result {
161-
Ok(dumped) => (dumped, false),
162-
Err(_) => {
163-
let cloudpickle = PyModule::import(py, "cloudpickle")?
164-
.getattr("dumps")
165-
.unwrap();
166-
let dumped = cloudpickle.call1((self.python_function.clone_ref(py),))?;
167-
(dumped, true)
168-
},
169-
};
170-
171-
// Write pickle metadata
172-
buf.push(use_cloudpickle as u8);
173-
buf.extend_from_slice(&*PYTHON3_VERSION);
174-
175-
// Write UDF metadata
176-
pl_serialize::serialize_into_writer::<_, _, true>(
177-
&mut *buf,
178-
&(
179-
self.output_type.clone(),
180-
self.is_elementwise,
181-
self.returns_scalar,
182-
),
183-
)?;
184-
185-
// Write UDF
186-
let dumped = dumped.extract::<PyBackedBytes>().unwrap();
187-
buf.extend_from_slice(&dumped);
188-
Ok(())
189-
})
129+
// Write UDF metadata
130+
pl_serialize::serialize_into_writer::<_, _, true>(
131+
&mut *buf,
132+
&(
133+
self.output_type.clone(),
134+
self.is_elementwise,
135+
self.returns_scalar,
136+
),
137+
)?;
138+
139+
pl_serialize::python_object_serialize(&self.python_function, buf)?;
140+
Ok(())
190141
}
191142
}
192143

crates/polars-python/src/lazyframe/serde.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ impl PyLazyFrame {
1515
fn serialize_binary(&self, py: Python<'_>, py_f: PyObject) -> PyResult<()> {
1616
let file = get_file_like(py_f, true)?;
1717
let writer = BufWriter::new(file);
18-
py.enter_polars(|| self.ldf.logical_plan.serialize_versioned(writer))
18+
py.enter_polars(|| {
19+
self.ldf
20+
.logical_plan
21+
.serialize_versioned(writer, Default::default())
22+
})
1923
}
2024

2125
/// Serialize into a JSON string.

crates/polars-utils/src/pl_serialize.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,81 @@ where
215215
}
216216
}
217217

218+
thread_local! {
219+
pub static USE_CLOUDPICKLE: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
220+
}
221+
222+
#[cfg(feature = "python")]
223+
pub fn python_object_serialize(
224+
pyobj: &pyo3::Py<pyo3::PyAny>,
225+
buf: &mut Vec<u8>,
226+
) -> PolarsResult<()> {
227+
use pyo3::Python;
228+
use pyo3::pybacked::PyBackedBytes;
229+
use pyo3::types::{PyAnyMethods, PyModule};
230+
231+
use crate::python_function::PYTHON3_VERSION;
232+
233+
let mut use_cloudpickle = USE_CLOUDPICKLE.get();
234+
let dumped = Python::with_gil(|py| {
235+
// Pickle with whatever pickling method was selected.
236+
if use_cloudpickle {
237+
let cloudpickle = PyModule::import(py, "cloudpickle")?.getattr("dumps")?;
238+
cloudpickle.call1((pyobj.clone_ref(py),))?
239+
} else {
240+
let pickle = PyModule::import(py, "pickle")?.getattr("dumps")?;
241+
match pickle.call1((pyobj.clone_ref(py),)) {
242+
Ok(dumped) => dumped,
243+
Err(_) => {
244+
use_cloudpickle = true;
245+
let cloudpickle = PyModule::import(py, "cloudpickle")?.getattr("dumps")?;
246+
cloudpickle.call1((pyobj.clone_ref(py),))?
247+
},
248+
}
249+
}
250+
.extract::<PyBackedBytes>()
251+
})?;
252+
253+
// Write pickle metadata
254+
buf.push(use_cloudpickle as u8);
255+
buf.extend_from_slice(&*PYTHON3_VERSION);
256+
257+
// Write UDF
258+
buf.extend_from_slice(&dumped);
259+
Ok(())
260+
}
261+
262+
#[cfg(feature = "python")]
263+
pub fn python_object_deserialize(buf: &[u8]) -> PolarsResult<pyo3::Py<pyo3::PyAny>> {
264+
use polars_error::polars_ensure;
265+
use pyo3::Python;
266+
use pyo3::types::{PyAnyMethods, PyBytes, PyModule};
267+
268+
use crate::python_function::PYTHON3_VERSION;
269+
270+
// Handle pickle metadata
271+
let use_cloudpickle = buf[0] != 0;
272+
if use_cloudpickle {
273+
let ser_py_version = &buf[1..3];
274+
let cur_py_version = *PYTHON3_VERSION;
275+
polars_ensure!(
276+
ser_py_version == cur_py_version,
277+
InvalidOperation:
278+
"current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
279+
(3, cur_py_version[0], cur_py_version[1]),
280+
(3, ser_py_version[0], ser_py_version[1] )
281+
);
282+
}
283+
let buf = &buf[3..];
284+
285+
Python::with_gil(|py| {
286+
let loads = PyModule::import(py, "pickle")?.getattr("loads")?;
287+
let arg = (PyBytes::new(py, buf),);
288+
let python_function = loads.call1(arg)?;
289+
Ok(python_function.into())
290+
})
291+
}
292+
218293
#[cfg(test)]
219294
mod tests {
220295
#[test]

0 commit comments

Comments
 (0)