@@ -88,25 +88,20 @@ mod llvm_enzyme {
8888 has_ret : bool ,
8989 ) -> AutoDiffAttrs {
9090 let dcx = ecx. sess . dcx ( ) ;
91- let mode = name ( & meta_item[ 1 ] ) ;
92- let Ok ( mode) = DiffMode :: from_str ( & mode) else {
93- dcx. emit_err ( errors:: AutoDiffInvalidMode { span : meta_item[ 1 ] . span ( ) , mode } ) ;
94- return AutoDiffAttrs :: error ( ) ;
95- } ;
9691
9792 // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
9893 // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
99- let mut first_activity = 2 ;
94+ let mut first_activity = 1 ;
10095
101- let width = if let [ _, _ , x, ..] = & meta_item[ ..]
96+ let width = if let [ _, x, ..] = & meta_item[ ..]
10297 && let Some ( x) = width ( x)
10398 {
104- first_activity = 3 ;
99+ first_activity = 2 ;
105100 match x. try_into ( ) {
106101 Ok ( x) => x,
107102 Err ( _) => {
108103 dcx. emit_err ( errors:: AutoDiffInvalidWidth {
109- span : meta_item[ 2 ] . span ( ) ,
104+ span : meta_item[ 1 ] . span ( ) ,
110105 width : x,
111106 } ) ;
112107 return AutoDiffAttrs :: error ( ) ;
@@ -150,7 +145,7 @@ mod llvm_enzyme {
150145 } ;
151146
152147 AutoDiffAttrs {
153- mode,
148+ mode : DiffMode :: Error ,
154149 width,
155150 ret_activity : * ret_activity,
156151 input_activity : input_activity. to_vec ( ) ,
@@ -165,6 +160,24 @@ mod llvm_enzyme {
165160 ts. push ( TokenTree :: Token ( comma. clone ( ) , Spacing :: Alone ) ) ;
166161 }
167162
163+ pub ( crate ) fn expand_forward (
164+ ecx : & mut ExtCtxt < ' _ > ,
165+ expand_span : Span ,
166+ meta_item : & ast:: MetaItem ,
167+ item : Annotatable ,
168+ ) -> Vec < Annotatable > {
169+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Forward )
170+ }
171+
172+ pub ( crate ) fn expand_reverse (
173+ ecx : & mut ExtCtxt < ' _ > ,
174+ expand_span : Span ,
175+ meta_item : & ast:: MetaItem ,
176+ item : Annotatable ,
177+ ) -> Vec < Annotatable > {
178+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Reverse )
179+ }
180+
168181 /// We expand the autodiff macro to generate a new placeholder function which passes
169182 /// type-checking and can be called by users. The function body of the placeholder function will
170183 /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +211,12 @@ mod llvm_enzyme {
198211 /// ```
199212 /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
200213 /// in CI.
201- pub ( crate ) fn expand (
214+ pub ( crate ) fn expand_with_mode (
202215 ecx : & mut ExtCtxt < ' _ > ,
203216 expand_span : Span ,
204217 meta_item : & ast:: MetaItem ,
205218 mut item : Annotatable ,
219+ mode : DiffMode ,
206220 ) -> Vec < Annotatable > {
207221 if cfg ! ( not( llvm_enzyme) ) {
208222 ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffSupportNotBuild { span : meta_item. span } ) ;
@@ -287,7 +301,8 @@ mod llvm_enzyme {
287301 ts. pop ( ) ;
288302 let ts: TokenStream = TokenStream :: from_iter ( ts) ;
289303
290- let x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret) ;
304+ let mut x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret) ;
305+ x. mode = mode;
291306 if !x. is_active ( ) {
292307 // We encountered an error, so we return the original item.
293308 // This allows us to potentially parse other attributes.
@@ -964,6 +979,6 @@ mod llvm_enzyme {
964979 trace ! ( "Generated signature: {:?}" , d_sig) ;
965980 ( d_sig, new_inputs, idents, false )
966981 }
967- }
968982
969- pub ( crate ) use llvm_enzyme:: expand;
983+
984+ pub ( crate ) use llvm_enzyme:: { expand, expand_forward, expand_reverse} ;
0 commit comments