@@ -28,6 +28,113 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2828 }
2929}
3030
31+ fn match_args_from_caller_to_enzyme < ' ll > (
32+ cx : & SimpleCx < ' ll > ,
33+ args : & mut Vec < & ' ll llvm:: Value > ,
34+ inputs : & [ DiffActivity ] ,
35+ outer_args : & [ & ' ll llvm:: Value ] ,
36+ ) {
37+ debug ! ( "matching autodiff arguments" ) ;
38+ // We now handle the issue that Rust level arguments not always match the llvm-ir level
39+ // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
40+ // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
41+ // need to match those.
42+ // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
43+ // using iterators and peek()?
44+ let mut outer_pos: usize = 0 ;
45+ let mut activity_pos = 0 ;
46+
47+ let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
48+ let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
49+ let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
50+ let enzyme_dupnoneed = cx. create_metadata ( "enzyme_dupnoneed" . to_string ( ) ) . unwrap ( ) ;
51+
52+ while activity_pos < inputs. len ( ) {
53+ let diff_activity = inputs[ activity_pos as usize ] ;
54+ // Duplicated arguments received a shadow argument, into which enzyme will write the
55+ // gradient.
56+ let ( activity, duplicated) : ( & Metadata , bool ) = match diff_activity {
57+ DiffActivity :: None => panic ! ( "not a valid input activity" ) ,
58+ DiffActivity :: Const => ( enzyme_const, false ) ,
59+ DiffActivity :: Active => ( enzyme_out, false ) ,
60+ DiffActivity :: ActiveOnly => ( enzyme_out, false ) ,
61+ DiffActivity :: Dual => ( enzyme_dup, true ) ,
62+ DiffActivity :: DualOnly => ( enzyme_dupnoneed, true ) ,
63+ DiffActivity :: Duplicated => ( enzyme_dup, true ) ,
64+ DiffActivity :: DuplicatedOnly => ( enzyme_dupnoneed, true ) ,
65+ DiffActivity :: FakeActivitySize => ( enzyme_const, false ) ,
66+ } ;
67+ let outer_arg = outer_args[ outer_pos] ;
68+ args. push ( cx. get_metadata_value ( activity) ) ;
69+ args. push ( outer_arg) ;
70+ if duplicated {
71+ // We know that duplicated args by construction have a following argument,
72+ // so this can not be out of bounds.
73+ let next_outer_arg = outer_args[ outer_pos + 1 ] ;
74+ let next_outer_ty = cx. val_ty ( next_outer_arg) ;
75+ // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
76+ // vectors behind references (&Vec<T>) are already supported. Users can not pass a
77+ // Vec by value for reverse mode, so this would only help forward mode autodiff.
78+ let slice = {
79+ if activity_pos + 1 >= inputs. len ( ) {
80+ // If there is no arg following our ptr, it also can't be a slice,
81+ // since that would lead to a ptr, int pair.
82+ false
83+ } else {
84+ let next_activity = inputs[ activity_pos + 1 ] ;
85+ // We analyze the MIR types and add this dummy activity if we visit a slice.
86+ next_activity == DiffActivity :: FakeActivitySize
87+ }
88+ } ;
89+ if slice {
90+ // A duplicated slice will have the following two outer_fn arguments:
91+ // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
92+ // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
93+ // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
94+ // int2 >= int1, which means the shadow vector is large enough to store the gradient.
95+ assert ! ( unsafe {
96+ llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
97+ } ) ;
98+ let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
99+ let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
100+ assert ! ( unsafe {
101+ llvm:: LLVMRustGetTypeKind ( next_outer_ty2) == llvm:: TypeKind :: Pointer
102+ } ) ;
103+ let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
104+ let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
105+ assert ! ( unsafe {
106+ llvm:: LLVMRustGetTypeKind ( next_outer_ty3) == llvm:: TypeKind :: Integer
107+ } ) ;
108+ args. push ( next_outer_arg2) ;
109+ args. push ( cx. get_metadata_value ( enzyme_const) ) ;
110+ args. push ( next_outer_arg) ;
111+ outer_pos += 4 ;
112+ activity_pos += 2 ;
113+ } else {
114+ // A duplicated pointer will have the following two outer_fn arguments:
115+ // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
116+ // (..., metadata! enzyme_dup, ptr, ptr, ...).
117+ if matches ! ( diff_activity, DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly )
118+ {
119+ assert ! (
120+ unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty) }
121+ == llvm:: TypeKind :: Pointer
122+ ) ;
123+ }
124+ // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
125+ args. push ( next_outer_arg) ;
126+ outer_pos += 2 ;
127+ activity_pos += 1 ;
128+ }
129+ } else {
130+ // We do not differentiate with resprect to this argument.
131+ // We already added the metadata and argument above, so just increase the counters.
132+ outer_pos += 1 ;
133+ activity_pos += 1 ;
134+ }
135+ }
136+ }
137+
31138/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
32139/// function with expected naming and calling conventions[^1] which will be
33140/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -132,12 +239,7 @@ fn generate_enzyme_call<'ll>(
132239 let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
133240 args. push ( fn_to_diff) ;
134241
135- let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
136- let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
137- let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
138- let enzyme_dupnoneed = cx. create_metadata ( "enzyme_dupnoneed" . to_string ( ) ) . unwrap ( ) ;
139242 let enzyme_primal_ret = cx. create_metadata ( "enzyme_primal_return" . to_string ( ) ) . unwrap ( ) ;
140-
141243 match output {
142244 DiffActivity :: Dual => {
143245 args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
@@ -148,95 +250,8 @@ fn generate_enzyme_call<'ll>(
148250 _ => { }
149251 }
150252
151- debug ! ( "matching autodiff arguments" ) ;
152- // We now handle the issue that Rust level arguments not always match the llvm-ir level
153- // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
154- // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
155- // need to match those.
156- // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
157- // using iterators and peek()?
158- let mut outer_pos: usize = 0 ;
159- let mut activity_pos = 0 ;
160253 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
161- while activity_pos < inputs. len ( ) {
162- let diff_activity = inputs[ activity_pos as usize ] ;
163- // Duplicated arguments received a shadow argument, into which enzyme will write the
164- // gradient.
165- let ( activity, duplicated) : ( & Metadata , bool ) = match diff_activity {
166- DiffActivity :: None => panic ! ( "not a valid input activity" ) ,
167- DiffActivity :: Const => ( enzyme_const, false ) ,
168- DiffActivity :: Active => ( enzyme_out, false ) ,
169- DiffActivity :: ActiveOnly => ( enzyme_out, false ) ,
170- DiffActivity :: Dual => ( enzyme_dup, true ) ,
171- DiffActivity :: DualOnly => ( enzyme_dupnoneed, true ) ,
172- DiffActivity :: Duplicated => ( enzyme_dup, true ) ,
173- DiffActivity :: DuplicatedOnly => ( enzyme_dupnoneed, true ) ,
174- DiffActivity :: FakeActivitySize => ( enzyme_const, false ) ,
175- } ;
176- let outer_arg = outer_args[ outer_pos] ;
177- args. push ( cx. get_metadata_value ( activity) ) ;
178- args. push ( outer_arg) ;
179- if duplicated {
180- // We know that duplicated args by construction have a following argument,
181- // so this can not be out of bounds.
182- let next_outer_arg = outer_args[ outer_pos + 1 ] ;
183- let next_outer_ty = cx. val_ty ( next_outer_arg) ;
184- // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
185- // vectors behind references (&Vec<T>) are already supported. Users can not pass a
186- // Vec by value for reverse mode, so this would only help forward mode autodiff.
187- let slice = {
188- if activity_pos + 1 >= inputs. len ( ) {
189- // If there is no arg following our ptr, it also can't be a slice,
190- // since that would lead to a ptr, int pair.
191- false
192- } else {
193- let next_activity = inputs[ activity_pos + 1 ] ;
194- // We analyze the MIR types and add this dummy activity if we visit a slice.
195- next_activity == DiffActivity :: FakeActivitySize
196- }
197- } ;
198- if slice {
199- // A duplicated slice will have the following two outer_fn arguments:
200- // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
201- // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
202- // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
203- // int2 >= int1, which means the shadow vector is large enough to store the gradient.
204- assert ! ( llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer ) ;
205- let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
206- let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
207- assert ! ( llvm:: LLVMRustGetTypeKind ( next_outer_ty2) == llvm:: TypeKind :: Pointer ) ;
208- let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
209- let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
210- assert ! ( llvm:: LLVMRustGetTypeKind ( next_outer_ty3) == llvm:: TypeKind :: Integer ) ;
211- args. push ( next_outer_arg2) ;
212- args. push ( cx. get_metadata_value ( enzyme_const) ) ;
213- args. push ( next_outer_arg) ;
214- outer_pos += 4 ;
215- activity_pos += 2 ;
216- } else {
217- // A duplicated pointer will have the following two outer_fn arguments:
218- // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
219- // (..., metadata! enzyme_dup, ptr, ptr, ...).
220- if matches ! (
221- diff_activity,
222- DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly
223- ) {
224- assert ! (
225- llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Pointer
226- ) ;
227- }
228- // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
229- args. push ( next_outer_arg) ;
230- outer_pos += 2 ;
231- activity_pos += 1 ;
232- }
233- } else {
234- // We do not differentiate with resprect to this argument.
235- // We already added the metadata and argument above, so just increase the counters.
236- outer_pos += 1 ;
237- activity_pos += 1 ;
238- }
239- }
254+ match_args_from_caller_to_enzyme ( & cx, & mut args, & inputs, & outer_args) ;
240255
241256 let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
242257
0 commit comments