@@ -13,15 +13,19 @@ use crate::transform::{simplify, MirPass, MirSource};
1313use itertools:: Itertools as _;
1414use rustc:: mir:: * ;
1515use rustc:: ty:: { Ty , TyCtxt } ;
16+ use rustc_index:: vec:: IndexVec ;
1617use rustc_target:: abi:: VariantIdx ;
18+ use std:: iter:: { Enumerate , Peekable } ;
19+ use std:: slice:: Iter ;
1720
1821/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
1922///
2023/// This is done by transforming basic blocks where the statements match:
2124///
2225/// ```rust
2326/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY );
24- /// ((_LOCAL_0 as Variant).FIELD: TY) = move _LOCAL_TMP;
27+ /// _TMP_2 = _LOCAL_TMP;
28+ /// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
2529/// discriminant(_LOCAL_0) = VAR_IDX;
2630/// ```
2731///
@@ -32,50 +36,306 @@ use rustc_target::abi::VariantIdx;
3236/// ```
3337pub struct SimplifyArmIdentity ;
3438
39+ #[ derive( Debug ) ]
40+ struct ArmIdentityInfo < ' tcx > {
41+ /// Storage location for the variant's field
42+ local_temp_0 : Local ,
43+ /// Storage location holding the variant being read from
44+ local_1 : Local ,
45+ /// The variant field being read from
46+ vf_s0 : VarField < ' tcx > ,
47+ /// Index of the statement which loads the variant being read
48+ get_variant_field_stmt : usize ,
49+
50+ /// Tracks each assignment to a temporary of the variant's field
51+ field_tmp_assignments : Vec < ( Local , Local ) > ,
52+
53+ /// Storage location holding the variant's field that was read from
54+ local_tmp_s1 : Local ,
55+ /// Storage location holding the enum that we are writing to
56+ local_0 : Local ,
57+ /// The variant field being written to
58+ vf_s1 : VarField < ' tcx > ,
59+
60+ /// Storage location that the discriminant is being written to
61+ set_discr_local : Local ,
62+ /// The variant being written
63+ set_discr_var_idx : VariantIdx ,
64+
65+ /// Index of the statement that should be overwritten as a move
66+ stmt_to_overwrite : usize ,
67+ /// SourceInfo for the new move
68+ source_info : SourceInfo ,
69+
70+ /// Indices of matching Storage{Live,Dead} statements encountered.
71+ /// (StorageLive index,, StorageDead index, Local)
72+ storage_stmts : Vec < ( usize , usize , Local ) > ,
73+
74+ /// The statements that should be removed (turned into nops)
75+ stmts_to_remove : Vec < usize > ,
76+ }
77+
78+ fn get_arm_identity_info < ' a , ' tcx > ( stmts : & ' a [ Statement < ' tcx > ] ) -> Option < ArmIdentityInfo < ' tcx > > {
79+ // This can't possibly match unless there are at least 3 statements in the block
80+ // so fail fast on tiny blocks.
81+ if stmts. len ( ) < 3 {
82+ return None ;
83+ }
84+
85+ let mut tmp_assigns = Vec :: new ( ) ;
86+ let mut nop_stmts = Vec :: new ( ) ;
87+ let mut storage_stmts = Vec :: new ( ) ;
88+ let mut storage_live_stmts = Vec :: new ( ) ;
89+ let mut storage_dead_stmts = Vec :: new ( ) ;
90+
91+ type StmtIter < ' a , ' tcx > = Peekable < Enumerate < Iter < ' a , Statement < ' tcx > > > > ;
92+
93+ fn is_storage_stmt < ' tcx > ( stmt : & Statement < ' tcx > ) -> bool {
94+ matches ! ( stmt. kind, StatementKind :: StorageLive ( _) | StatementKind :: StorageDead ( _) )
95+ }
96+
97+ fn try_eat_storage_stmts < ' a , ' tcx > (
98+ stmt_iter : & mut StmtIter < ' a , ' tcx > ,
99+ storage_live_stmts : & mut Vec < ( usize , Local ) > ,
100+ storage_dead_stmts : & mut Vec < ( usize , Local ) > ,
101+ ) {
102+ while stmt_iter. peek ( ) . map ( |( _, stmt) | is_storage_stmt ( stmt) ) . unwrap_or ( false ) {
103+ let ( idx, stmt) = stmt_iter. next ( ) . unwrap ( ) ;
104+
105+ if let StatementKind :: StorageLive ( l) = stmt. kind {
106+ storage_live_stmts. push ( ( idx, l) ) ;
107+ } else if let StatementKind :: StorageDead ( l) = stmt. kind {
108+ storage_dead_stmts. push ( ( idx, l) ) ;
109+ }
110+ }
111+ }
112+
113+ fn is_tmp_storage_stmt < ' tcx > ( stmt : & Statement < ' tcx > ) -> bool {
114+ if let StatementKind :: Assign ( box ( place, Rvalue :: Use ( op) ) ) = & stmt. kind {
115+ if let Operand :: Copy ( p) | Operand :: Move ( p) = op {
116+ return place. as_local ( ) . is_some ( ) && p. as_local ( ) . is_some ( ) ;
117+ }
118+ }
119+
120+ false
121+ }
122+
123+ fn try_eat_assign_tmp_stmts < ' a , ' tcx > (
124+ stmt_iter : & mut StmtIter < ' a , ' tcx > ,
125+ tmp_assigns : & mut Vec < ( Local , Local ) > ,
126+ nop_stmts : & mut Vec < usize > ,
127+ ) {
128+ while stmt_iter. peek ( ) . map ( |( _, stmt) | is_tmp_storage_stmt ( stmt) ) . unwrap_or ( false ) {
129+ let ( idx, stmt) = stmt_iter. next ( ) . unwrap ( ) ;
130+
131+ if let StatementKind :: Assign ( box ( place, Rvalue :: Use ( op) ) ) = & stmt. kind {
132+ if let Operand :: Copy ( p) | Operand :: Move ( p) = op {
133+ tmp_assigns. push ( ( place. as_local ( ) . unwrap ( ) , p. as_local ( ) . unwrap ( ) ) ) ;
134+ nop_stmts. push ( idx) ;
135+ }
136+ }
137+ }
138+ }
139+
140+ fn find_storage_live_dead_stmts_for_local < ' tcx > (
141+ l : Local ,
142+ stmts : & [ Statement < ' tcx > ] ,
143+ ) -> Option < ( usize , usize ) > {
144+ trace ! ( "looking for {:?}" , l) ;
145+ let mut storage_live_stmt = None ;
146+ let mut storage_dead_stmt = None ;
147+ for ( idx, stmt) in stmts. iter ( ) . enumerate ( ) {
148+ if stmt. kind == StatementKind :: StorageLive ( l) {
149+ storage_live_stmt = Some ( idx) ;
150+ } else if stmt. kind == StatementKind :: StorageDead ( l) {
151+ storage_dead_stmt = Some ( idx) ;
152+ }
153+ }
154+
155+ Some ( ( storage_live_stmt?, storage_dead_stmt. unwrap_or ( usize:: MAX ) ) )
156+ }
157+
158+ // Try to match the expected MIR structure with the basic block we're processing.
159+ // We want to see something that looks like:
160+ // ```
161+ // (StorageLive(_) | StorageDead(_));*
162+ // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
163+ // (StorageLive(_) | StorageDead(_));*
164+ // (tmp_n+1 = tmp_n);*
165+ // (StorageLive(_) | StorageDead(_));*
166+ // (tmp_n+1 = tmp_n);*
167+ // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp;
168+ // discriminant(LOCAL_FROM) = VariantIdx;
169+ // (StorageLive(_) | StorageDead(_));*
170+ // ```
171+ let mut stmt_iter = stmts. iter ( ) . enumerate ( ) . peekable ( ) ;
172+
173+ try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
174+
175+ let ( get_variant_field_stmt, stmt) = stmt_iter. next ( ) ?;
176+ let ( local_tmp_s0, local_1, vf_s0) = match_get_variant_field ( stmt) ?;
177+
178+ try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
179+
180+ try_eat_assign_tmp_stmts ( & mut stmt_iter, & mut tmp_assigns, & mut nop_stmts) ;
181+
182+ try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
183+
184+ try_eat_assign_tmp_stmts ( & mut stmt_iter, & mut tmp_assigns, & mut nop_stmts) ;
185+
186+ let ( idx, stmt) = stmt_iter. next ( ) ?;
187+ let ( local_tmp_s1, local_0, vf_s1) = match_set_variant_field ( stmt) ?;
188+ nop_stmts. push ( idx) ;
189+
190+ let ( idx, stmt) = stmt_iter. next ( ) ?;
191+ let ( set_discr_local, set_discr_var_idx) = match_set_discr ( stmt) ?;
192+ let discr_stmt_source_info = stmt. source_info ;
193+ nop_stmts. push ( idx) ;
194+
195+ try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
196+
197+ for ( live_idx, live_local) in storage_live_stmts {
198+ if let Some ( i) = storage_dead_stmts. iter ( ) . rposition ( |( _, l) | * l == live_local) {
199+ let ( dead_idx, _) = storage_dead_stmts. swap_remove ( i) ;
200+ storage_stmts. push ( ( live_idx, dead_idx, live_local) ) ;
201+
202+ if live_local == local_tmp_s0 {
203+ nop_stmts. push ( get_variant_field_stmt) ;
204+ }
205+ }
206+ }
207+
208+ nop_stmts. sort ( ) ;
209+
210+ // Use one of the statements we're going to discard between the point
211+ // where the storage location for the variant field becomes live and
212+ // is killed.
213+ let ( live_idx, daed_idx) = find_storage_live_dead_stmts_for_local ( local_tmp_s0, stmts) ?;
214+ let stmt_to_overwrite =
215+ nop_stmts. iter ( ) . find ( |stmt_idx| live_idx < * * stmt_idx && * * stmt_idx < daed_idx) ;
216+
217+ Some ( ArmIdentityInfo {
218+ local_temp_0 : local_tmp_s0,
219+ local_1,
220+ vf_s0,
221+ get_variant_field_stmt,
222+ field_tmp_assignments : tmp_assigns,
223+ local_tmp_s1,
224+ local_0,
225+ vf_s1,
226+ set_discr_local,
227+ set_discr_var_idx,
228+ stmt_to_overwrite : * stmt_to_overwrite?,
229+ source_info : discr_stmt_source_info,
230+ storage_stmts,
231+ stmts_to_remove : nop_stmts,
232+ } )
233+ }
234+
235+ fn optimization_applies < ' tcx > (
236+ opt_info : & ArmIdentityInfo < ' tcx > ,
237+ local_decls : & IndexVec < Local , LocalDecl < ' tcx > > ,
238+ ) -> bool {
239+ trace ! ( "testing if optimization applies..." ) ;
240+
241+ // FIXME(wesleywiser): possible relax this restriction?
242+ if opt_info. local_0 == opt_info. local_1 {
243+ trace ! ( "NO: moving into ourselves" ) ;
244+ return false ;
245+ } else if opt_info. vf_s0 != opt_info. vf_s1 {
246+ trace ! ( "NO: the field-and-variant information do not match" ) ;
247+ return false ;
248+ } else if local_decls[ opt_info. local_0 ] . ty != local_decls[ opt_info. local_1 ] . ty {
249+ // FIXME(Centril,oli-obk): possibly relax to same layout?
250+ trace ! ( "NO: source and target locals have different types" ) ;
251+ return false ;
252+ } else if ( opt_info. local_0 , opt_info. vf_s0 . var_idx )
253+ != ( opt_info. set_discr_local , opt_info. set_discr_var_idx )
254+ {
255+ trace ! ( "NO: the discriminants do not match" ) ;
256+ return false ;
257+ }
258+
259+ // Verify the assigment chain consists of the form b = a; c = b; d = c; etc...
260+ if opt_info. field_tmp_assignments . len ( ) == 0 {
261+ trace ! ( "NO: no assignments found" ) ;
262+ }
263+ let mut last_assigned_to = opt_info. field_tmp_assignments [ 0 ] . 1 ;
264+ let source_local = last_assigned_to;
265+ for ( l, r) in & opt_info. field_tmp_assignments {
266+ if * r != last_assigned_to {
267+ trace ! ( "NO: found unexpected assignment {:?} = {:?}" , l, r) ;
268+ return false ;
269+ }
270+
271+ last_assigned_to = * l;
272+ }
273+
274+ if source_local != opt_info. local_temp_0 {
275+ trace ! (
276+ "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}" ,
277+ source_local,
278+ opt_info. local_temp_0
279+ ) ;
280+ return false ;
281+ } else if last_assigned_to != opt_info. local_tmp_s1 {
282+ trace ! (
283+ "NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}" ,
284+ last_assigned_to,
285+ opt_info. local_tmp_s1
286+ ) ;
287+ return false ;
288+ }
289+
290+ trace ! ( "SUCCESS: optimization applies!" ) ;
291+ return true ;
292+ }
293+
35294impl < ' tcx > MirPass < ' tcx > for SimplifyArmIdentity {
36- fn run_pass ( & self , _: TyCtxt < ' tcx > , _: MirSource < ' tcx > , body : & mut BodyAndCache < ' tcx > ) {
295+ fn run_pass ( & self , _: TyCtxt < ' tcx > , source : MirSource < ' tcx > , body : & mut BodyAndCache < ' tcx > ) {
296+ trace ! ( "running SimplifyArmIdentity on {:?}" , source) ;
37297 let ( basic_blocks, local_decls) = body. basic_blocks_and_local_decls_mut ( ) ;
38298 for bb in basic_blocks {
39- // Need 3 statements:
40- let ( s0, s1, s2) = match & mut * bb. statements {
41- [ s0, s1, s2] => ( s0, s1, s2) ,
42- _ => continue ,
43- } ;
299+ if let Some ( opt_info) = get_arm_identity_info ( & bb. statements ) {
300+ trace ! ( "got opt_info = {:#?}" , opt_info) ;
301+ if !optimization_applies ( & opt_info, local_decls) {
302+ debug ! ( "optimization skipped for {:?}" , source) ;
303+ continue ;
304+ }
44305
45- // Pattern match on the form we want:
46- let ( local_tmp_s0, local_1, vf_s0) = match match_get_variant_field ( s0) {
47- None => continue ,
48- Some ( x) => x,
49- } ;
50- let ( local_tmp_s1, local_0, vf_s1) = match match_set_variant_field ( s1) {
51- None => continue ,
52- Some ( x) => x,
53- } ;
54- if local_tmp_s0 != local_tmp_s1
55- // Avoid moving into ourselves.
56- || local_0 == local_1
57- // The field-and-variant information match up.
58- || vf_s0 != vf_s1
59- // Source and target locals have the same type.
60- // FIXME(Centril | oli-obk): possibly relax to same layout?
61- || local_decls[ local_0] . ty != local_decls[ local_1] . ty
62- // We're setting the discriminant of `local_0` to this variant.
63- || Some ( ( local_0, vf_s0. var_idx ) ) != match_set_discr ( s2)
64- {
65- continue ;
66- }
306+ // Also remove unused Storage{Live,Dead} statements which correspond
307+ // to temps used previously.
308+ for ( live_idx, dead_idx, local) in & opt_info. storage_stmts {
309+ // The temporary that we've read the variant field into is scoped to this block,
310+ // so we can remove the assignment.
311+ if * local == opt_info. local_temp_0 {
312+ bb. statements [ opt_info. get_variant_field_stmt ] . make_nop ( ) ;
313+ }
67314
68- // Right shape; transform!
69- s0 . source_info = s2 . source_info ;
70- match & mut s0 . kind {
71- StatementKind :: Assign ( box ( place , rvalue ) ) => {
72- * place = local_0 . into ( ) ;
73- * rvalue = Rvalue :: Use ( Operand :: Move ( local_1 . into ( ) ) ) ;
315+ for ( left , right ) in & opt_info . field_tmp_assignments {
316+ if local == left || local == right {
317+ bb . statements [ * live_idx ] . make_nop ( ) ;
318+ bb . statements [ * dead_idx ] . make_nop ( ) ;
319+ }
320+ }
74321 }
75- _ => unreachable ! ( ) ,
322+
323+ // Right shape; transform
324+ for stmt_idx in opt_info. stmts_to_remove {
325+ bb. statements [ stmt_idx] . make_nop ( ) ;
326+ }
327+
328+ let stmt = & mut bb. statements [ opt_info. stmt_to_overwrite ] ;
329+ stmt. source_info = opt_info. source_info ;
330+ stmt. kind = StatementKind :: Assign ( box (
331+ opt_info. local_0 . into ( ) ,
332+ Rvalue :: Use ( Operand :: Move ( opt_info. local_1 . into ( ) ) ) ,
333+ ) ) ;
334+
335+ bb. statements . retain ( |stmt| stmt. kind != StatementKind :: Nop ) ;
336+
337+ trace ! ( "block is now {:?}" , bb. statements) ;
76338 }
77- s1. make_nop ( ) ;
78- s2. make_nop ( ) ;
79339 }
80340 }
81341}
@@ -129,7 +389,7 @@ fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)>
129389 }
130390}
131391
132- #[ derive( PartialEq ) ]
392+ #[ derive( PartialEq , Debug ) ]
133393struct VarField < ' tcx > {
134394 field : Field ,
135395 field_ty : Ty < ' tcx > ,
0 commit comments