@@ -77,6 +77,7 @@ fn build_group_by_fallback(
77
77
/// Such an expression is defined as the elementwise combination of scalar
78
78
/// aggregations of elementwise combinations of the input columns / scalar literals.
79
79
#[ recursive]
80
+ #[ allow( clippy:: too_many_arguments) ]
80
81
fn try_lower_elementwise_scalar_agg_expr (
81
82
expr : Node ,
82
83
outer_name : Option < PlSmallStr > ,
@@ -85,6 +86,7 @@ fn try_lower_elementwise_scalar_agg_expr(
85
86
expr_arena : & mut Arena < AExpr > ,
86
87
agg_exprs : & mut Vec < ExprIR > ,
87
88
uniq_input_exprs : & mut PlIndexMap < u32 , PlSmallStr > ,
89
+ uniq_agg_exprs : & mut PlIndexMap < u32 , PlSmallStr > ,
88
90
) -> Option < Node > {
89
91
// Helper macro to simplify recursive calls.
90
92
macro_rules! lower_rec {
@@ -97,6 +99,7 @@ fn try_lower_elementwise_scalar_agg_expr(
97
99
expr_arena,
98
100
agg_exprs,
99
101
uniq_input_exprs,
102
+ uniq_agg_exprs,
100
103
)
101
104
} ;
102
105
}
@@ -219,34 +222,44 @@ fn try_lower_elementwise_scalar_agg_expr(
219
222
| IRAggExpr :: Var ( input, ..)
220
223
| IRAggExpr :: Std ( input, ..)
221
224
| IRAggExpr :: Count ( input, ..) => {
222
- if is_input_independent ( * input, expr_arena, expr_cache) {
225
+ let agg = agg. clone ( ) ;
226
+ let input = * input;
227
+ if is_input_independent ( input, expr_arena, expr_cache) {
223
228
// TODO: we could simply return expr here, but we first need an is_scalar function, because if
224
229
// it is not a scalar we need to return expr.implode().
225
230
return None ;
226
231
}
227
232
228
- if !is_elementwise_rec_cached ( * input, expr_arena, expr_cache) {
233
+ if !is_elementwise_rec_cached ( input, expr_arena, expr_cache) {
229
234
return None ;
230
235
}
231
236
232
- let mut trans_agg = agg. clone ( ) ;
233
- let input_id = expr_merger. get_uniq_id ( * input) . unwrap ( ) ;
234
- let input_col = uniq_input_exprs
235
- . entry ( input_id)
236
- . or_insert_with ( unique_column_name)
237
+ let agg_id = expr_merger. get_uniq_id ( expr) . unwrap ( ) ;
238
+ let name = uniq_agg_exprs
239
+ . entry ( agg_id)
240
+ . or_insert_with ( || {
241
+ let mut trans_agg = agg;
242
+ let input_id = expr_merger. get_uniq_id ( input) . unwrap ( ) ;
243
+ let input_col = uniq_input_exprs
244
+ . entry ( input_id)
245
+ . or_insert_with ( unique_column_name)
246
+ . clone ( ) ;
247
+ let input_col_node = expr_arena. add ( AExpr :: Column ( input_col. clone ( ) ) ) ;
248
+ trans_agg. set_input ( input_col_node) ;
249
+ let trans_agg_node = expr_arena. add ( AExpr :: Agg ( trans_agg) ) ;
250
+
251
+ // Add to aggregation expressions and replace with a reference to its output.
252
+ let agg_expr = if let Some ( name) = outer_name {
253
+ ExprIR :: new ( trans_agg_node, OutputName :: Alias ( name) )
254
+ } else {
255
+ ExprIR :: new ( trans_agg_node, OutputName :: Alias ( unique_column_name ( ) ) )
256
+ } ;
257
+ agg_exprs. push ( agg_expr. clone ( ) ) ;
258
+ agg_expr. output_name ( ) . clone ( )
259
+ } )
237
260
. clone ( ) ;
238
- let input_col_node = expr_arena. add ( AExpr :: Column ( input_col. clone ( ) ) ) ;
239
- trans_agg. set_input ( input_col_node) ;
240
- let trans_agg_node = expr_arena. add ( AExpr :: Agg ( trans_agg) ) ;
241
-
242
- // Add to aggregation expressions and replace with a reference to its output.
243
- let agg_expr = if let Some ( name) = outer_name {
244
- ExprIR :: new ( trans_agg_node, OutputName :: Alias ( name) )
245
- } else {
246
- ExprIR :: new ( trans_agg_node, OutputName :: Alias ( unique_column_name ( ) ) )
247
- } ;
248
- let result_node = expr_arena. add ( AExpr :: Column ( agg_expr. output_name ( ) . clone ( ) ) ) ;
249
- agg_exprs. push ( agg_expr) ;
261
+
262
+ let result_node = expr_arena. add ( AExpr :: Column ( name) ) ;
250
263
Some ( result_node)
251
264
} ,
252
265
IRAggExpr :: Median ( ..)
@@ -338,6 +351,8 @@ fn try_build_streaming_group_by(
338
351
let trans_output_node = expr_arena. add ( AExpr :: Column ( uniq_name) ) ;
339
352
trans_output_exprs. push ( ExprIR :: new ( trans_output_node, output_name) ) ;
340
353
}
354
+
355
+ let mut uniq_agg_exprs = PlIndexMap :: new ( ) ;
341
356
for agg in aggs {
342
357
let trans_node = try_lower_elementwise_scalar_agg_expr (
343
358
agg. node ( ) ,
@@ -347,6 +362,7 @@ fn try_build_streaming_group_by(
347
362
expr_arena,
348
363
& mut trans_agg_exprs,
349
364
& mut uniq_input_exprs,
365
+ & mut uniq_agg_exprs,
350
366
) ?;
351
367
let output_name = OutputName :: Alias ( agg. output_name ( ) . clone ( ) ) ;
352
368
trans_output_exprs. push ( ExprIR :: new ( trans_node, output_name) ) ;
0 commit comments