Skip to content

Commit 123f846

Browse files
committed
refactor: Introduce Row{Encode,Decode} as FunctionExpr
1 parent e24d694 commit 123f846

File tree

20 files changed

+632
-294
lines changed

20 files changed

+632
-294
lines changed

crates/polars-core/src/chunked_array/ops/row_encode.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,55 @@ pub fn _get_rows_encoded_ca_unordered(
244244
_get_rows_encoded_unordered(by)
245245
.map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array()))
246246
}
247+
248+
pub fn row_encoding_decode(
249+
ca: &BinaryOffsetChunked,
250+
fields: &[Field],
251+
opts: &[RowEncodingOptions],
252+
) -> PolarsResult<StructChunked> {
253+
let (ctxts, dtypes) = fields
254+
.iter()
255+
.map(|f| {
256+
(
257+
get_row_encoding_context(f.dtype()),
258+
f.dtype().to_physical().to_arrow(CompatLevel::newest()),
259+
)
260+
})
261+
.collect::<(Vec<_>, Vec<_>)>();
262+
263+
let struct_arrow_dtype = ArrowDataType::Struct(
264+
fields
265+
.iter()
266+
.map(|v| v.to_physical().to_arrow(CompatLevel::newest()))
267+
.collect(),
268+
);
269+
270+
let mut rows = Vec::new();
271+
let chunks = ca
272+
.downcast_iter()
273+
.map(|array| {
274+
let decoded_arrays = unsafe {
275+
polars_row::decode::decode_rows_from_binary(
276+
array, &opts, &ctxts, &dtypes, &mut rows,
277+
)
278+
};
279+
assert_eq!(decoded_arrays.len(), fields.len());
280+
281+
StructArray::new(
282+
struct_arrow_dtype.clone(),
283+
array.len(),
284+
decoded_arrays,
285+
None,
286+
)
287+
.to_boxed()
288+
})
289+
.collect::<Vec<_>>();
290+
291+
Ok(unsafe {
292+
StructChunked::from_chunks_and_dtype(
293+
ca.name().clone(),
294+
chunks,
295+
DataType::Struct(fields.to_vec()),
296+
)
297+
})
298+
}

crates/polars-core/src/datatypes/field.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ impl Field {
121121
pub fn to_arrow(&self, compat_level: CompatLevel) -> ArrowField {
122122
self.dtype.to_arrow_field(self.name.clone(), compat_level)
123123
}
124+
125+
pub fn to_physical(&self) -> Field {
126+
Self {
127+
name: self.name.clone(),
128+
dtype: self.dtype().to_physical(),
129+
}
130+
}
124131
}
125132

126133
impl AsRef<DataType> for Field {

crates/polars-core/src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub use arrow::datatypes::{ArrowSchema, Field as ArrowField};
77
pub use arrow::legacy::prelude::*;
88
pub(crate) use arrow::trusted_len::TrustedLen;
99
pub use polars_compute::rolling::{QuantileMethod, RollingFnParams, RollingVarParams};
10+
pub use polars_row::RowEncodingOptions;
1011
pub use polars_utils::aliases::*;
1112
pub use polars_utils::index::{ChunkId, IdxSize, NullableIdxSize};
1213
pub use polars_utils::pl_str::PlSmallStr;

crates/polars-plan/src/dsl/function_expr/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,9 @@ pub enum FunctionExpr {
363363
#[cfg(feature = "reinterpret")]
364364
Reinterpret(bool),
365365
ExtendConstant,
366+
367+
RowEncode(RowEncodingVariant),
368+
RowDecode(Vec<(PlSmallStr, DataTypeExpr)>, RowEncodingVariant),
366369
}
367370

368371
impl Hash for FunctionExpr {
@@ -650,6 +653,12 @@ impl Hash for FunctionExpr {
650653
ExtendConstant => {},
651654
#[cfg(feature = "top_k")]
652655
TopKBy { descending } => descending.hash(state),
656+
657+
RowEncode(variants) => variants.hash(state),
658+
RowDecode(fs, variants) => {
659+
fs.hash(state);
660+
variants.hash(state);
661+
},
653662
}
654663
}
655664
}
@@ -848,6 +857,9 @@ impl Display for FunctionExpr {
848857
#[cfg(feature = "reinterpret")]
849858
Reinterpret(_) => "reinterpret",
850859
ExtendConstant => "extend_constant",
860+
861+
RowEncode(..) => "row_encode",
862+
RowDecode(..) => "row_decode",
851863
};
852864
write!(f, "{s}")
853865
}

crates/polars-plan/src/plans/aexpr/function_expr/mod.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ mod rolling;
5757
mod rolling_by;
5858
#[cfg(feature = "round_series")]
5959
mod round;
60+
mod row_encode;
6061
#[cfg(feature = "row_hash")]
6162
mod row_hash;
6263
pub(super) mod schema;
@@ -112,6 +113,7 @@ pub use self::range::IRRangeFunction;
112113
pub use self::rolling::IRRollingFunction;
113114
#[cfg(feature = "rolling_window_by")]
114115
pub use self::rolling_by::IRRollingFunctionBy;
116+
pub use self::row_encode::RowEncodingVariant;
115117
#[cfg(feature = "strings")]
116118
pub use self::strings::IRStringFunction;
117119
#[cfg(feature = "dtype-struct")]
@@ -413,6 +415,9 @@ pub enum IRFunctionExpr {
413415
#[cfg(feature = "reinterpret")]
414416
Reinterpret(bool),
415417
ExtendConstant,
418+
419+
RowEncode(RowEncodingVariant),
420+
RowDecode(Vec<Field>, RowEncodingVariant),
416421
}
417422

418423
impl Hash for IRFunctionExpr {
@@ -705,6 +710,12 @@ impl Hash for IRFunctionExpr {
705710
ExtendConstant => {},
706711
#[cfg(feature = "top_k")]
707712
TopKBy { descending } => descending.hash(state),
713+
714+
RowEncode(variants) => variants.hash(state),
715+
RowDecode(fs, variants) => {
716+
fs.hash(state);
717+
variants.hash(state);
718+
},
708719
}
709720
}
710721
}
@@ -907,6 +918,9 @@ impl Display for IRFunctionExpr {
907918
#[cfg(feature = "reinterpret")]
908919
Reinterpret(_) => "reinterpret",
909920
ExtendConstant => "extend_constant",
921+
922+
RowEncode(..) => "row_encode",
923+
RowDecode(..) => "row_decode",
910924
};
911925
write!(f, "{s}")
912926
}
@@ -1384,6 +1398,11 @@ impl From<IRFunctionExpr> for SpecialEq<Arc<dyn ColumnsUdf>> {
13841398
#[cfg(feature = "reinterpret")]
13851399
Reinterpret(signed) => map!(dispatch::reinterpret, signed),
13861400
ExtendConstant => map_as_slice!(dispatch::extend_constant),
1401+
1402+
RowEncode(variants) => map_as_slice!(row_encode::encode, variants.clone()),
1403+
RowDecode(fs, variants) => {
1404+
map_as_slice!(row_encode::decode, fs.clone(), variants.clone())
1405+
},
13871406
}
13881407
}
13891408
}
@@ -1590,6 +1609,8 @@ impl IRFunctionExpr {
15901609
#[cfg(feature = "reinterpret")]
15911610
F::Reinterpret(_) => FunctionOptions::elementwise(),
15921611
F::ExtendConstant => FunctionOptions::groupwise(),
1612+
1613+
F::RowEncode(..) | F::RowDecode(..) => FunctionOptions::elementwise(),
15931614
}
15941615
}
15951616
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use polars_core::prelude::row_encode::{
2+
_get_rows_encoded_ca, _get_rows_encoded_ca_unordered, encode_rows_unordered,
3+
row_encoding_decode,
4+
};
5+
use polars_core::prelude::{Column, Field, IntoColumn, RowEncodingOptions};
6+
use polars_error::PolarsResult;
7+
use polars_utils::pl_str::PlSmallStr;
8+
9+
/// Selects which row encoding the `RowEncode` / `RowDecode` functions use.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum RowEncodingVariant {
    /// Encoding built with unsorted options; takes no per-column flags.
    Unordered,
    /// Encoding built with sorted options, configured per column.
    ///
    /// When a list is `None`, `false` is used for every column.
    Ordered {
        /// Per-column `descending` flag passed to the encoder.
        descending: Option<Vec<bool>>,
        /// Per-column `nulls_last` flag passed to the encoder.
        nulls_last: Option<Vec<bool>>,
    },
}
19+
20+
pub fn encode(c: &mut [Column], variant: RowEncodingVariant) -> PolarsResult<Column> {
21+
let name = PlSmallStr::from_static("row_encoded");
22+
match variant {
23+
RowEncodingVariant::Unordered => _get_rows_encoded_ca_unordered(name, c),
24+
RowEncodingVariant::Ordered {
25+
descending,
26+
nulls_last,
27+
} => {
28+
let descending = descending.unwrap_or_else(|| vec![false; c.len()]);
29+
let nulls_last = nulls_last.unwrap_or_else(|| vec![false; c.len()]);
30+
31+
assert_eq!(c.len(), descending.len());
32+
assert_eq!(c.len(), nulls_last.len());
33+
34+
_get_rows_encoded_ca(name, c, &descending, &nulls_last)
35+
},
36+
}
37+
.map(IntoColumn::into_column)
38+
}
39+
40+
pub fn decode(
41+
c: &mut [Column],
42+
fields: Vec<Field>,
43+
variant: RowEncodingVariant,
44+
) -> PolarsResult<Column> {
45+
assert_eq!(c.len(), 1);
46+
let ca = c[0].binary_offset()?;
47+
48+
let mut opts = Vec::with_capacity(fields.len());
49+
match variant {
50+
RowEncodingVariant::Unordered => opts.extend(std::iter::repeat_n(
51+
RowEncodingOptions::new_unsorted(),
52+
fields.len(),
53+
)),
54+
RowEncodingVariant::Ordered {
55+
descending,
56+
nulls_last,
57+
} => {
58+
let descending = descending.unwrap_or_else(|| vec![false; fields.len()]);
59+
let nulls_last = nulls_last.unwrap_or_else(|| vec![false; fields.len()]);
60+
61+
assert_eq!(fields.len(), descending.len());
62+
assert_eq!(fields.len(), nulls_last.len());
63+
64+
opts.extend(
65+
descending
66+
.into_iter()
67+
.zip(nulls_last)
68+
.map(|(d, n)| RowEncodingOptions::new_sorted(d, n)),
69+
)
70+
},
71+
}
72+
73+
row_encoding_decode(ca, &fields, &opts).map(IntoColumn::into_column)
74+
}

crates/polars-plan/src/plans/aexpr/function_expr/schema.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,9 @@ impl IRFunctionExpr {
428428
mapper.with_dtype(dt)
429429
},
430430
ExtendConstant => mapper.with_same_dtype(),
431+
432+
RowEncode(_) => mapper.try_map_field(|_| Ok(Field::new(PlSmallStr::from_static("row-encode"), DataType::BinaryOffset))),
433+
RowDecode(fields, _) => mapper.with_dtype(DataType::Struct(fields.to_vec())),
431434
}
432435
}
433436

crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ fn function_input_wildcard_expansion(function: &FunctionExpr) -> FunctionExpansi
8383
| F::ReduceHorizontal { .. }
8484
| F::SumHorizontal { .. }
8585
| F::MeanHorizontal { .. }
86+
| F::RowEncode(..)
8687
);
8788
let mut allow_empty_inputs = matches!(
8889
function,

crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,14 @@ pub(super) fn convert_functions(
996996
polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
997997
I::ExtendConstant
998998
},
999+
1000+
F::RowEncode(v) => I::RowEncode(v),
1001+
F::RowDecode(fs, v) => I::RowDecode(
1002+
fs.into_iter()
1003+
.map(|(name, dt_expr)| Ok(Field::new(name, dt_expr.into_datatype(ctx.schema)?)))
1004+
.collect::<PolarsResult<Vec<_>>>()?,
1005+
v,
1006+
),
9991007
};
10001008

10011009
let mut options = ir_function.function_options();

crates/polars-plan/src/plans/conversion/ir_to_dsl.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,12 @@ pub fn ir_function_to_dsl(input: Vec<Expr>, function: IRFunctionExpr) -> Expr {
10731073
#[cfg(feature = "reinterpret")]
10741074
IF::Reinterpret(v) => F::Reinterpret(v),
10751075
IF::ExtendConstant => F::ExtendConstant,
1076+
1077+
IF::RowEncode(v) => F::RowEncode(v),
1078+
IF::RowDecode(fs, v) => F::RowDecode(
1079+
fs.into_iter().map(|f| (f.name, f.dtype.into())).collect(),
1080+
v,
1081+
),
10761082
};
10771083

10781084
Expr::Function { input, function }

0 commit comments

Comments
 (0)