@@ -98,7 +98,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
9898 fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
9999 trace ! ( "running EarlyOtherwiseBranch on {:?}" , body. source) ;
100100
101- let mut should_cleanup = false ;
101+ let mut should_apply_patch = false ;
102+ let mut patch = MirPatch :: new ( body) ;
102103
103104 // Also consider newly generated bbs in the same pass
104105 for i in 0 ..body. basic_blocks . len ( ) {
@@ -112,7 +113,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
112113
113114 trace ! ( "SUCCESS: found optimization possibility to apply: {:?}" , & opt_data) ;
114115
115- should_cleanup = true ;
116+ should_apply_patch = true ;
116117
117118 let TerminatorKind :: SwitchInt { discr : parent_op, targets : parent_targets } =
118119 & bbs[ parent] . terminator ( ) . kind
@@ -129,8 +130,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
129130 let statements_before = bbs[ parent] . statements . len ( ) ;
130131 let parent_end = Location { block : parent, statement_index : statements_before } ;
131132
132- let mut patch = MirPatch :: new ( body) ;
133-
134133 let ( second_discriminant_temp, second_operand) = if opt_data. hoist_discriminant {
135134 // create temp to store second discriminant in, `_s` in example above
136135 let second_discriminant_temp =
@@ -242,13 +241,12 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
242241 ) ;
243242 }
244243 }
245-
246- patch. apply ( body) ;
247244 }
248245
249246 // Since this optimization adds new basic blocks and invalidates others,
250247 // clean up the cfg to make it nicer for other passes
251- if should_cleanup {
248+ if should_apply_patch {
249+ patch. apply ( body) ;
252250 simplify_cfg ( body) ;
253251 }
254252 }
@@ -275,19 +273,15 @@ fn evaluate_candidate<'tcx>(
275273 return None ;
276274 } ;
277275 let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
278- let ( _, child) = targets. iter ( ) . next ( ) ?;
279- let child_terminator = & bbs[ child] . terminator ( ) ;
280- let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
281- & child_terminator. kind
276+ let mut targets_iter = targets. iter ( ) ;
277+ let ( _, first_child) = targets_iter. next ( ) ?;
278+ let first_child_terminator = & bbs[ first_child] . terminator ( ) ;
279+ let TerminatorKind :: SwitchInt { targets : first_child_targets, discr : first_child_discr } =
280+ & first_child_terminator. kind
282281 else {
283282 return None ;
284283 } ;
285- let child_ty = child_discr. ty ( body. local_decls ( ) , tcx) ;
286- if bbs[ child] . statements . len ( ) > 1 {
287- return None ;
288- }
289- let hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
290- let child_place = if hoist_discriminant {
284+ let hoist_discriminant = if bbs[ first_child] . statements . len ( ) == 1 {
291285 if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
292286 // Someone could write code like this:
293287 // ```rust
@@ -320,7 +314,44 @@ fn evaluate_candidate<'tcx>(
320314 // So we need the `otherwise` branch has no statements and an unreachable terminator.
321315 return None ;
322316 }
323- let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind )
317+ true
318+ } else if bbs[ first_child] . statements . is_empty ( ) {
319+ false
320+ } else {
321+ return None ;
322+ } ;
323+ let destination = if hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
324+ first_child_targets. otherwise ( )
325+ } else {
326+ if first_child_targets. otherwise ( ) != targets. otherwise ( ) {
327+ return None ;
328+ }
329+ targets. otherwise ( )
330+ } ;
331+ while let Some ( ( _, child) ) = targets_iter. next ( ) {
332+ let child_branch = & bbs[ child] ;
333+ // In order for the optimization to be correct, the branch must...
334+ // ...have exactly one or empty statement
335+ if ( hoist_discriminant && child_branch. statements . len ( ) != 1 )
336+ || ( !hoist_discriminant && !child_branch. statements . is_empty ( ) )
337+ {
338+ return None ;
339+ }
340+ // ...terminate on a `SwitchInt` that invalidates that local
341+ let TerminatorKind :: SwitchInt { targets : child_targets, .. } =
342+ & child_branch. terminator ( ) . kind
343+ else {
344+ return None ;
345+ } ;
346+ if child_targets. otherwise ( ) != destination {
347+ return None ;
348+ }
349+ // Make sure there are only two branches.
350+ }
351+ let child_ty = first_child_discr. ty ( body. local_decls ( ) , tcx) ;
352+ let child_place = if hoist_discriminant {
353+ let Some ( StatementKind :: Assign ( boxed) ) =
354+ & bbs[ first_child] . statements . first ( ) . map ( |x| & x. kind )
324355 else {
325356 return None ;
326357 } ;
@@ -329,26 +360,17 @@ fn evaluate_candidate<'tcx>(
329360 } ;
330361 * child_place
331362 } else {
332- let TerminatorKind :: SwitchInt { discr, .. } = & bbs[ child ] . terminator ( ) . kind else {
363+ let TerminatorKind :: SwitchInt { discr, .. } = & bbs[ first_child ] . terminator ( ) . kind else {
333364 return None ;
334365 } ;
335366 let Operand :: Copy ( child_place) = discr else {
336367 return None ;
337368 } ;
338369 * child_place
339370 } ;
340- let destination = if hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
341- child_targets. otherwise ( )
342- } else {
343- targets. otherwise ( )
344- } ;
345371
346- let TerminatorKind :: SwitchInt { targets : child_targets, .. } = & bbs[ child] . terminator ( ) . kind
347- else {
348- return None ;
349- } ;
350372 // Verify that the optimization is legal for each branch
351- let Some ( ( may_same_target_value, _) ) = child_targets . iter ( ) . next ( ) else {
373+ let Some ( ( may_same_target_value, _) ) = first_child_targets . iter ( ) . next ( ) else {
352374 return None ;
353375 } ;
354376 let mut same_target_value = Some ( may_same_target_value) ;
@@ -357,7 +379,6 @@ fn evaluate_candidate<'tcx>(
357379 & bbs[ child] ,
358380 may_same_target_value,
359381 child_place,
360- destination,
361382 hoist_discriminant,
362383 ) {
363384 same_target_value = None ;
@@ -369,13 +390,7 @@ fn evaluate_candidate<'tcx>(
369390 return None ;
370391 }
371392 for ( value, child) in targets. iter ( ) {
372- if !verify_candidate_branch (
373- & bbs[ child] ,
374- value,
375- child_place,
376- destination,
377- hoist_discriminant,
378- ) {
393+ if !verify_candidate_branch ( & bbs[ child] , value, child_place, hoist_discriminant) {
379394 return None ;
380395 }
381396 }
@@ -384,7 +399,7 @@ fn evaluate_candidate<'tcx>(
384399 destination,
385400 child_place,
386401 child_ty,
387- child_source : child_terminator . source_info ,
402+ child_source : first_child_terminator . source_info ,
388403 hoist_discriminant,
389404 same_target_value,
390405 } )
@@ -394,20 +409,11 @@ fn verify_candidate_branch<'tcx>(
394409 branch : & BasicBlockData < ' tcx > ,
395410 value : u128 ,
396411 place : Place < ' tcx > ,
397- destination : BasicBlock ,
398412 hoist_discriminant : bool ,
399413) -> bool {
400- // In order for the optimization to be correct, the branch must...
401- // ...have exactly one statement
402- if ( hoist_discriminant && branch. statements . len ( ) != 1 )
403- || ( !hoist_discriminant && !branch. statements . is_empty ( ) )
404- {
405- return false ;
406- }
407- // ...terminate on a `SwitchInt` that invalidates that local
408414 let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } = & branch. terminator ( ) . kind
409415 else {
410- return false ;
416+ unreachable ! ( )
411417 } ;
412418 if hoist_discriminant {
413419 // ...assign the discriminant of `place` in that statement
@@ -428,10 +434,6 @@ fn verify_candidate_branch<'tcx>(
428434 return false ;
429435 }
430436 }
431- // ...fall through to `destination` if the switch misses
432- if destination != targets. otherwise ( ) {
433- return false ;
434- }
435437 // ...have a branch for value `value`
436438 let mut iter = targets. iter ( ) ;
437439 let Some ( ( target_value, _) ) = iter. next ( ) else {
0 commit comments