Skip to content

Commit 0fa7141

Browse files
authored
perf: Improve streaming groupby CSE (#23092)
1 parent 665a202 commit 0fa7141

File tree

1 file changed

+35
-19
lines changed

1 file changed

+35
-19
lines changed

crates/polars-stream/src/physical_plan/lower_group_by.rs

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ fn build_group_by_fallback(
7777
/// Such an expression is defined as the elementwise combination of scalar
7878
/// aggregations of elementwise combinations of the input columns / scalar literals.
7979
#[recursive]
80+
#[allow(clippy::too_many_arguments)]
8081
fn try_lower_elementwise_scalar_agg_expr(
8182
expr: Node,
8283
outer_name: Option<PlSmallStr>,
@@ -85,6 +86,7 @@ fn try_lower_elementwise_scalar_agg_expr(
8586
expr_arena: &mut Arena<AExpr>,
8687
agg_exprs: &mut Vec<ExprIR>,
8788
uniq_input_exprs: &mut PlIndexMap<u32, PlSmallStr>,
89+
uniq_agg_exprs: &mut PlIndexMap<u32, PlSmallStr>,
8890
) -> Option<Node> {
8991
// Helper macro to simplify recursive calls.
9092
macro_rules! lower_rec {
@@ -97,6 +99,7 @@ fn try_lower_elementwise_scalar_agg_expr(
9799
expr_arena,
98100
agg_exprs,
99101
uniq_input_exprs,
102+
uniq_agg_exprs,
100103
)
101104
};
102105
}
@@ -219,34 +222,44 @@ fn try_lower_elementwise_scalar_agg_expr(
219222
| IRAggExpr::Var(input, ..)
220223
| IRAggExpr::Std(input, ..)
221224
| IRAggExpr::Count(input, ..) => {
222-
if is_input_independent(*input, expr_arena, expr_cache) {
225+
let agg = agg.clone();
226+
let input = *input;
227+
if is_input_independent(input, expr_arena, expr_cache) {
223228
// TODO: we could simply return expr here, but we first need an is_scalar function, because if
224229
// it is not a scalar we need to return expr.implode().
225230
return None;
226231
}
227232

228-
if !is_elementwise_rec_cached(*input, expr_arena, expr_cache) {
233+
if !is_elementwise_rec_cached(input, expr_arena, expr_cache) {
229234
return None;
230235
}
231236

232-
let mut trans_agg = agg.clone();
233-
let input_id = expr_merger.get_uniq_id(*input).unwrap();
234-
let input_col = uniq_input_exprs
235-
.entry(input_id)
236-
.or_insert_with(unique_column_name)
237+
let agg_id = expr_merger.get_uniq_id(expr).unwrap();
238+
let name = uniq_agg_exprs
239+
.entry(agg_id)
240+
.or_insert_with(|| {
241+
let mut trans_agg = agg;
242+
let input_id = expr_merger.get_uniq_id(input).unwrap();
243+
let input_col = uniq_input_exprs
244+
.entry(input_id)
245+
.or_insert_with(unique_column_name)
246+
.clone();
247+
let input_col_node = expr_arena.add(AExpr::Column(input_col.clone()));
248+
trans_agg.set_input(input_col_node);
249+
let trans_agg_node = expr_arena.add(AExpr::Agg(trans_agg));
250+
251+
// Add to aggregation expressions and replace with a reference to its output.
252+
let agg_expr = if let Some(name) = outer_name {
253+
ExprIR::new(trans_agg_node, OutputName::Alias(name))
254+
} else {
255+
ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()))
256+
};
257+
agg_exprs.push(agg_expr.clone());
258+
agg_expr.output_name().clone()
259+
})
237260
.clone();
238-
let input_col_node = expr_arena.add(AExpr::Column(input_col.clone()));
239-
trans_agg.set_input(input_col_node);
240-
let trans_agg_node = expr_arena.add(AExpr::Agg(trans_agg));
241-
242-
// Add to aggregation expressions and replace with a reference to its output.
243-
let agg_expr = if let Some(name) = outer_name {
244-
ExprIR::new(trans_agg_node, OutputName::Alias(name))
245-
} else {
246-
ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()))
247-
};
248-
let result_node = expr_arena.add(AExpr::Column(agg_expr.output_name().clone()));
249-
agg_exprs.push(agg_expr);
261+
262+
let result_node = expr_arena.add(AExpr::Column(name));
250263
Some(result_node)
251264
},
252265
IRAggExpr::Median(..)
@@ -338,6 +351,8 @@ fn try_build_streaming_group_by(
338351
let trans_output_node = expr_arena.add(AExpr::Column(uniq_name));
339352
trans_output_exprs.push(ExprIR::new(trans_output_node, output_name));
340353
}
354+
355+
let mut uniq_agg_exprs = PlIndexMap::new();
341356
for agg in aggs {
342357
let trans_node = try_lower_elementwise_scalar_agg_expr(
343358
agg.node(),
@@ -347,6 +362,7 @@ fn try_build_streaming_group_by(
347362
expr_arena,
348363
&mut trans_agg_exprs,
349364
&mut uniq_input_exprs,
365+
&mut uniq_agg_exprs,
350366
)?;
351367
let output_name = OutputName::Alias(agg.output_name().clone());
352368
trans_output_exprs.push(ExprIR::new(trans_node, output_name));

0 commit comments

Comments
 (0)