@@ -132,18 +132,29 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
132132
133133 let mut patch = MirPatch :: new ( body) ;
134134
135- // create temp to store second discriminant in, `_s` in example above
136- let second_discriminant_temp =
137- patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
135+ let ( second_discriminant_temp, second_operand) = if opt_data. need_hoist_discriminant {
136+ // create temp to store second discriminant in, `_s` in example above
137+ let second_discriminant_temp =
138+ patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
138139
139- patch. add_statement ( parent_end, StatementKind :: StorageLive ( second_discriminant_temp) ) ;
140+ patch. add_statement (
141+ parent_end,
142+ StatementKind :: StorageLive ( second_discriminant_temp) ,
143+ ) ;
140144
141- // create assignment of discriminant
142- patch. add_assign (
143- parent_end,
144- Place :: from ( second_discriminant_temp) ,
145- Rvalue :: Discriminant ( opt_data. child_place ) ,
146- ) ;
145+ // create assignment of discriminant
146+ patch. add_assign (
147+ parent_end,
148+ Place :: from ( second_discriminant_temp) ,
149+ Rvalue :: Discriminant ( opt_data. child_place ) ,
150+ ) ;
151+ (
152+ Some ( second_discriminant_temp) ,
153+ Operand :: Move ( Place :: from ( second_discriminant_temp) ) ,
154+ )
155+ } else {
156+ ( None , Operand :: Copy ( opt_data. child_place ) )
157+ } ;
147158
148159 // create temp to store inequality comparison between the two discriminants, `_t` in
149160 // example above
@@ -152,11 +163,9 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
152163 let comp_temp = patch. new_temp ( comp_res_type, opt_data. child_source . span ) ;
153164 patch. add_statement ( parent_end, StatementKind :: StorageLive ( comp_temp) ) ;
154165
155- // create inequality comparison between the two discriminants
156- let comp_rvalue = Rvalue :: BinaryOp (
157- nequal,
158- Box :: new ( ( parent_op. clone ( ) , Operand :: Move ( Place :: from ( second_discriminant_temp) ) ) ) ,
159- ) ;
166+ // create inequality comparison
167+ let comp_rvalue =
168+ Rvalue :: BinaryOp ( nequal, Box :: new ( ( parent_op. clone ( ) , second_operand) ) ) ;
160169 patch. add_statement (
161170 parent_end,
162171 StatementKind :: Assign ( Box :: new ( ( Place :: from ( comp_temp) , comp_rvalue) ) ) ,
@@ -192,8 +201,13 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
192201 TerminatorKind :: if_ ( Operand :: Move ( Place :: from ( comp_temp) ) , true_case, false_case) ,
193202 ) ;
194203
195- // generate StorageDead for the second_discriminant_temp not in use anymore
196- patch. add_statement ( parent_end, StatementKind :: StorageDead ( second_discriminant_temp) ) ;
204+ if let Some ( second_discriminant_temp) = second_discriminant_temp {
205+ // generate StorageDead for the second_discriminant_temp not in use anymore
206+ patch. add_statement (
207+ parent_end,
208+ StatementKind :: StorageDead ( second_discriminant_temp) ,
209+ ) ;
210+ }
197211
198212 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
199213 // the switch
@@ -221,6 +235,7 @@ struct OptimizationData<'tcx> {
221235 child_place : Place < ' tcx > ,
222236 child_ty : Ty < ' tcx > ,
223237 child_source : SourceInfo ,
238+ need_hoist_discriminant : bool ,
224239}
225240
226241fn evaluate_candidate < ' tcx > (
@@ -234,70 +249,128 @@ fn evaluate_candidate<'tcx>(
234249 return None ;
235250 } ;
236251 let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
237- if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
238- // Someone could write code like this:
239- // ```rust
240- // let Q = val;
241- // if discriminant(P) == otherwise {
242- // let ptr = &mut Q as *mut _ as *mut u8;
243- // // It may be difficult for us to effectively determine whether values are valid.
244- // // Invalid values can come from all sorts of corners.
245- // unsafe { *ptr = 10; }
246- // }
247- //
248- // match P {
249- // A => match Q {
250- // A => {
251- // // code
252- // }
253- // _ => {
254- // // don't use Q
255- // }
256- // }
257- // _ => {
258- // // don't use Q
259- // }
260- // };
261- // ```
262- //
263- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
264- // invalid value, which is UB.
265- // In order to fix this, **we would either need to show that the discriminant computation of
266- // `place` is computed in all branches**.
267- // FIXME(#95162) For the moment, we adopt a conservative approach and
268- // consider only the `otherwise` branch has no statements and an unreachable terminator.
269- return None ;
270- }
271252 let ( _, child) = targets. iter ( ) . next ( ) ?;
272- let child_terminator = & bbs[ child] . terminator ( ) ;
273- let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
274- & child_terminator. kind
253+
254+ let Terminator {
255+ kind : TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } ,
256+ source_info,
257+ } = bbs[ child] . terminator ( )
275258 else {
276259 return None ;
277260 } ;
278261 let child_ty = child_discr. ty ( body. local_decls ( ) , tcx) ;
279262 if child_ty != parent_ty {
280263 return None ;
281264 }
282- let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind ) else {
265+
266+ // We only handle:
267+ // ```
268+ // bb4: {
269+ // _8 = discriminant((_3.1: Enum1));
270+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
271+ // }
272+ // ```
273+ // and
274+ // ```
275+ // bb2: {
276+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
277+ // }
278+ // ```
279+ if bbs[ child] . statements . len ( ) > 1 {
283280 return None ;
281+ }
282+
283+ // When thie BB has exactly one statement, this statement should be discriminant.
284+ let need_hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
285+ let child_place = if need_hoist_discriminant {
286+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
287+ // Someone could write code like this:
288+ // ```rust
289+ // let Q = val;
290+ // if discriminant(P) == otherwise {
291+ // let ptr = &mut Q as *mut _ as *mut u8;
292+ // // It may be difficult for us to effectively determine whether values are valid.
293+ // // Invalid values can come from all sorts of corners.
294+ // unsafe { *ptr = 10; }
295+ // }
296+ //
297+ // match P {
298+ // A => match Q {
299+ // A => {
300+ // // code
301+ // }
302+ // _ => {
303+ // // don't use Q
304+ // }
305+ // }
306+ // _ => {
307+ // // don't use Q
308+ // }
309+ // };
310+ // ```
311+ //
312+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
313+ // invalid value, which is UB.
314+ // In order to fix this, **we would either need to show that the discriminant computation of
315+ // `place` is computed in all branches**.
316+ // FIXME(#95162) For the moment, we adopt a conservative approach and
317+ // consider only the `otherwise` branch has no statements and an unreachable terminator.
318+ return None ;
319+ }
320+ // Handle:
321+ // ```
322+ // bb4: {
323+ // _8 = discriminant((_3.1: Enum1));
324+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
325+ // }
326+ // ```
327+ let [
328+ Statement {
329+ kind : StatementKind :: Assign ( box ( _, Rvalue :: Discriminant ( child_place) ) ) ,
330+ ..
331+ } ,
332+ ] = bbs[ child] . statements . as_slice ( )
333+ else {
334+ return None ;
335+ } ;
336+ * child_place
337+ } else {
338+ // Handle:
339+ // ```
340+ // bb2: {
341+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
342+ // }
343+ // ```
344+ let Operand :: Copy ( child_place) = child_discr else {
345+ return None ;
346+ } ;
347+ * child_place
284348 } ;
285- let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
286- return None ;
349+ let destination = if need_hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( )
350+ {
351+ child_targets. otherwise ( )
352+ } else {
353+ targets. otherwise ( )
287354 } ;
288- let destination = child_targets. otherwise ( ) ;
289355
290356 // Verify that the optimization is legal for each branch
291357 for ( value, child) in targets. iter ( ) {
292- if !verify_candidate_branch ( & bbs[ child] , value, * child_place, destination) {
358+ if !verify_candidate_branch (
359+ & bbs[ child] ,
360+ value,
361+ child_place,
362+ destination,
363+ need_hoist_discriminant,
364+ ) {
293365 return None ;
294366 }
295367 }
296368 Some ( OptimizationData {
297369 destination,
298- child_place : * child_place ,
370+ child_place,
299371 child_ty,
300- child_source : child_terminator. source_info ,
372+ child_source : * source_info,
373+ need_hoist_discriminant,
301374 } )
302375}
303376
@@ -306,45 +379,48 @@ fn verify_candidate_branch<'tcx>(
306379 value : u128 ,
307380 place : Place < ' tcx > ,
308381 destination : BasicBlock ,
382+ need_hoist_discriminant : bool ,
309383) -> bool {
310- // In order for the optimization to be correct, the branch must...
311- // ...have exactly one statement
312- let [ statement] = branch. statements . as_slice ( ) else {
313- return false ;
314- } ;
315- // ...assign the discriminant of `place` in that statement
316- let StatementKind :: Assign ( boxed) = & statement. kind else { return false } ;
317- let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed else { return false } ;
318- if * from_place != place {
319- return false ;
320- }
321- // ...make that assignment to a local
322- if discr_place. projection . len ( ) != 0 {
323- return false ;
324- }
325- // ...terminate on a `SwitchInt` that invalidates that local
326- let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } = & branch. terminator ( ) . kind
327- else {
384+ // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
385+ let TerminatorKind :: SwitchInt { discr : switch_op, targets } = & branch. terminator ( ) . kind else {
328386 return false ;
329387 } ;
330- if * switch_op != Operand :: Move ( * discr_place) {
331- return false ;
388+ if need_hoist_discriminant {
389+ // If we need hoist discriminant, the branch must have exactly one statement.
390+ let [ statement] = branch. statements . as_slice ( ) else {
391+ return false ;
392+ } ;
393+ // The statement must assign the discriminant of `place`.
394+ let StatementKind :: Assign ( box ( discr_place, Rvalue :: Discriminant ( from_place) ) ) =
395+ statement. kind
396+ else {
397+ return false ;
398+ } ;
399+ if from_place != place {
400+ return false ;
401+ }
402+ // The assignment must invalidate a local that terminate on a `SwitchInt`.
403+ if !discr_place. projection . is_empty ( ) || * switch_op != Operand :: Move ( discr_place) {
404+ return false ;
405+ }
406+ } else {
407+ // If we don't need hoist discriminant, the branch must not have any statements.
408+ if !branch. statements . is_empty ( ) {
409+ return false ;
410+ }
411+ // The place on `SwitchInt` must be the same.
412+ if * switch_op != Operand :: Copy ( place) {
413+ return false ;
414+ }
332415 }
333- // ... fall through to `destination` if the switch misses
416+ // It must fall through to `destination` if the switch misses.
334417 if destination != targets. otherwise ( ) {
335418 return false ;
336419 }
337- // ... have a branch for value `value`
420+ // It must have exactly one branch for value `value` and have no more branches.
338421 let mut iter = targets. iter ( ) ;
339- let Some ( ( target_value, _) ) = iter. next ( ) else {
422+ let ( Some ( ( target_value, _) ) , None ) = ( iter. next ( ) , iter . next ( ) ) else {
340423 return false ;
341424 } ;
342- if target_value != value {
343- return false ;
344- }
345- // ...and have no more branches
346- if let Some ( _) = iter. next ( ) {
347- return false ;
348- }
349- true
425+ target_value == value
350426}
0 commit comments