@@ -86,27 +86,23 @@ mod llvm_enzyme {
8686 ecx : & mut ExtCtxt < ' _ > ,
8787 meta_item : & ThinVec < MetaItemInner > ,
8888 has_ret : bool ,
89+ mode : DiffMode ,
8990 ) -> AutoDiffAttrs {
9091 let dcx = ecx. sess . dcx ( ) ;
91- let mode = name ( & meta_item[ 1 ] ) ;
92- let Ok ( mode) = DiffMode :: from_str ( & mode) else {
93- dcx. emit_err ( errors:: AutoDiffInvalidMode { span : meta_item[ 1 ] . span ( ) , mode } ) ;
94- return AutoDiffAttrs :: error ( ) ;
95- } ;
9692
9793 // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
9894 // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
99- let mut first_activity = 2 ;
95+ let mut first_activity = 1 ;
10096
101- let width = if let [ _, _ , x, ..] = & meta_item[ ..]
97+ let width = if let [ _, x, ..] = & meta_item[ ..]
10298 && let Some ( x) = width ( x)
10399 {
104- first_activity = 3 ;
100+ first_activity = 2 ;
105101 match x. try_into ( ) {
106102 Ok ( x) => x,
107103 Err ( _) => {
108104 dcx. emit_err ( errors:: AutoDiffInvalidWidth {
109- span : meta_item[ 2 ] . span ( ) ,
105+ span : meta_item[ 1 ] . span ( ) ,
110106 width : x,
111107 } ) ;
112108 return AutoDiffAttrs :: error ( ) ;
@@ -165,6 +161,24 @@ mod llvm_enzyme {
165161 ts. push ( TokenTree :: Token ( comma. clone ( ) , Spacing :: Alone ) ) ;
166162 }
167163
164+ pub ( crate ) fn expand_forward (
165+ ecx : & mut ExtCtxt < ' _ > ,
166+ expand_span : Span ,
167+ meta_item : & ast:: MetaItem ,
168+ item : Annotatable ,
169+ ) -> Vec < Annotatable > {
170+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Forward )
171+ }
172+
173+ pub ( crate ) fn expand_reverse (
174+ ecx : & mut ExtCtxt < ' _ > ,
175+ expand_span : Span ,
176+ meta_item : & ast:: MetaItem ,
177+ item : Annotatable ,
178+ ) -> Vec < Annotatable > {
179+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Reverse )
180+ }
181+
168182 /// We expand the autodiff macro to generate a new placeholder function which passes
169183 /// type-checking and can be called by users. The function body of the placeholder function will
170184 /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +212,12 @@ mod llvm_enzyme {
198212 /// ```
199213 /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
200214 /// in CI.
201- pub ( crate ) fn expand (
215+ pub ( crate ) fn expand_with_mode (
202216 ecx : & mut ExtCtxt < ' _ > ,
203217 expand_span : Span ,
204218 meta_item : & ast:: MetaItem ,
205219 mut item : Annotatable ,
220+ mode : DiffMode ,
206221 ) -> Vec < Annotatable > {
207222 if cfg ! ( not( llvm_enzyme) ) {
208223 ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffSupportNotBuild { span : meta_item. span } ) ;
@@ -245,29 +260,41 @@ mod llvm_enzyme {
245260 // create TokenStream from vec elemtents:
246261 // meta_item doesn't have a .tokens field
247262 let mut ts: Vec < TokenTree > = vec ! [ ] ;
248- if meta_item_vec. len ( ) < 2 {
249- // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
250- // input and output args.
263+ if meta_item_vec. len ( ) < 1 {
264+ // At the bare minimum, we need a fnc name.
251265 dcx. emit_err ( errors:: AutoDiffMissingConfig { span : item. span ( ) } ) ;
252266 return vec ! [ item] ;
253267 }
254268
255- meta_item_inner_to_ts ( & meta_item_vec[ 1 ] , & mut ts) ;
269+ let mode_symbol = match mode {
270+ DiffMode :: Forward => sym:: Forward ,
271+ DiffMode :: Reverse => sym:: Reverse ,
272+ _ => unreachable ! ( "Unsupported mode: {:?}" , mode) ,
273+ } ;
274+
275+ // Insert mode token
276+ let mode_token = Token :: new ( TokenKind :: Ident ( mode_symbol, false . into ( ) ) , Span :: default ( ) ) ;
277+ ts. insert ( 0 , TokenTree :: Token ( mode_token, Spacing :: Joint ) ) ;
278+ ts. insert (
279+ 1 ,
280+ TokenTree :: Token ( Token :: new ( TokenKind :: Comma , Span :: default ( ) ) , Spacing :: Alone ) ,
281+ ) ;
256282
257283 // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
258284 // If it is not given, we default to 1 (scalar mode).
259285 let start_position;
260286 let kind: LitKind = LitKind :: Integer ;
261287 let symbol;
262- if meta_item_vec. len ( ) >= 3
263- && let Some ( width) = width ( & meta_item_vec[ 2 ] )
288+ if meta_item_vec. len ( ) >= 2
289+ && let Some ( width) = width ( & meta_item_vec[ 1 ] )
264290 {
265- start_position = 3 ;
291+ start_position = 2 ;
266292 symbol = Symbol :: intern ( & width. to_string ( ) ) ;
267293 } else {
268- start_position = 2 ;
294+ start_position = 1 ;
269295 symbol = sym:: integer ( 1 ) ;
270296 }
297+
271298 let l: Lit = Lit { kind, symbol, suffix : None } ;
272299 let t = Token :: new ( TokenKind :: Literal ( l) , Span :: default ( ) ) ;
273300 let comma = Token :: new ( TokenKind :: Comma , Span :: default ( ) ) ;
@@ -289,7 +316,7 @@ mod llvm_enzyme {
289316 ts. pop ( ) ;
290317 let ts: TokenStream = TokenStream :: from_iter ( ts) ;
291318
292- let x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret) ;
319+ let x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret, mode ) ;
293320 if !x. is_active ( ) {
294321 // We encountered an error, so we return the original item.
295322 // This allows us to potentially parse other attributes.
@@ -1017,4 +1044,4 @@ mod llvm_enzyme {
10171044 }
10181045}
10191046
1020- pub ( crate ) use llvm_enzyme:: expand ;
1047+ pub ( crate ) use llvm_enzyme:: { expand_forward , expand_reverse } ;
0 commit comments