@@ -26,6 +26,16 @@ mod llvm_enzyme {
2626
2727 use crate :: errors;
2828
29+ pub ( crate ) fn outer_normal_attr (
30+ kind : & P < rustc_ast:: NormalAttr > ,
31+ id : rustc_ast:: AttrId ,
32+ span : Span ,
33+ ) -> rustc_ast:: Attribute {
34+ let style = rustc_ast:: AttrStyle :: Outer ;
35+ let kind = rustc_ast:: AttrKind :: Normal ( kind. clone ( ) ) ;
36+ rustc_ast:: Attribute { kind, id, style, span }
37+ }
38+
2939 // If we have a default `()` return type or explicitley `()` return type,
3040 // then we often can skip doing some work.
3141 fn has_ret ( ty : & FnRetTy ) -> bool {
@@ -224,20 +234,8 @@ mod llvm_enzyme {
224234 . filter ( |a| * * a == DiffActivity :: Active || * * a == DiffActivity :: ActiveOnly )
225235 . count ( ) as u32 ;
226236 let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
227- let new_decl_span = d_sig. span ;
228237 let d_body = gen_enzyme_body (
229- ecx,
230- & x,
231- n_active,
232- & sig,
233- & d_sig,
234- primal,
235- & new_args,
236- span,
237- sig_span,
238- new_decl_span,
239- idents,
240- errored,
238+ ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
241239 ) ;
242240 let d_ident = first_ident ( & meta_item_vec[ 0 ] ) ;
243241
@@ -270,36 +268,39 @@ mod llvm_enzyme {
270268 } ;
271269 let inline_never_attr = P ( ast:: NormalAttr { item : inline_item, tokens : None } ) ;
272270 let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
273- let attr: ast:: Attribute = ast:: Attribute {
274- kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
275- id : new_id,
276- style : ast:: AttrStyle :: Outer ,
277- span,
278- } ;
271+ let attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
279272 let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
280- let inline_never: ast:: Attribute = ast:: Attribute {
281- kind : ast:: AttrKind :: Normal ( inline_never_attr) ,
282- id : new_id,
283- style : ast:: AttrStyle :: Outer ,
284- span,
285- } ;
273+ let inline_never = outer_normal_attr ( & inline_never_attr, new_id, span) ;
274+
275+ // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
276+ fn same_attribute ( attr : & ast:: AttrKind , item : & ast:: AttrKind ) -> bool {
277+ match ( attr, item) {
278+ ( ast:: AttrKind :: Normal ( a) , ast:: AttrKind :: Normal ( b) ) => {
279+ let a = & a. item . path ;
280+ let b = & b. item . path ;
281+ a. segments . len ( ) == b. segments . len ( )
282+ && a. segments . iter ( ) . zip ( b. segments . iter ( ) ) . all ( |( a, b) | a. ident == b. ident )
283+ }
284+ _ => false ,
285+ }
286+ }
286287
287288 // Don't add it multiple times:
288289 let orig_annotatable: Annotatable = match item {
289290 Annotatable :: Item ( ref mut iitem) => {
290- if !iitem. attrs . iter ( ) . any ( |a| a . id == attr. id ) {
291+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind , & attr. kind ) ) {
291292 iitem. attrs . push ( attr) ;
292293 }
293- if !iitem. attrs . iter ( ) . any ( |a| a . id == inline_never. id ) {
294+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind , & inline_never. kind ) ) {
294295 iitem. attrs . push ( inline_never. clone ( ) ) ;
295296 }
296297 Annotatable :: Item ( iitem. clone ( ) )
297298 }
298299 Annotatable :: AssocItem ( ref mut assoc_item, i @ Impl ) => {
299- if !assoc_item. attrs . iter ( ) . any ( |a| a . id == attr. id ) {
300+ if !assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind , & attr. kind ) ) {
300301 assoc_item. attrs . push ( attr) ;
301302 }
302- if !assoc_item. attrs . iter ( ) . any ( |a| a . id == inline_never. id ) {
303+ if !assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a . kind , & inline_never. kind ) ) {
303304 assoc_item. attrs . push ( inline_never. clone ( ) ) ;
304305 }
305306 Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
@@ -314,13 +315,7 @@ mod llvm_enzyme {
314315 delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
315316 tokens : ts,
316317 } ) ;
317- let d_attr: ast:: Attribute = ast:: Attribute {
318- kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
319- id : new_id,
320- style : ast:: AttrStyle :: Outer ,
321- span,
322- } ;
323-
318+ let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
324319 let d_annotatable = if is_impl {
325320 let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
326321 let d_fn = P ( ast:: AssocItem {
@@ -361,30 +356,27 @@ mod llvm_enzyme {
361356 ty
362357 }
363358
364- /// We only want this function to type-check, since we will replace the body
365- /// later on llvm level. Using `loop {}` does not cover all return types anymore,
366- /// so instead we build something that should pass. We also add a inline_asm
367- /// line, as one more barrier for rustc to prevent inlining of this function.
368- /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
369- /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
370- /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
371- /// this function (which should never happen, since it is only a placeholder).
372- /// Finally, we also add back_box usages of all input arguments, to prevent rustc
373- /// from optimizing any arguments away.
374- fn gen_enzyme_body (
359+ // Will generate a body of the type:
360+ // ```
361+ // {
362+ // unsafe {
363+ // asm!("NOP");
364+ // }
365+ // ::core::hint::black_box(primal(args));
366+ // ::core::hint::black_box((args, ret));
367+ // <This part remains to be done by following function>
368+ // }
369+ // ```
370+ fn init_body_helper (
375371 ecx : & ExtCtxt < ' _ > ,
376- x : & AutoDiffAttrs ,
377- n_active : u32 ,
378- sig : & ast:: FnSig ,
379- d_sig : & ast:: FnSig ,
372+ span : Span ,
380373 primal : Ident ,
381374 new_names : & [ String ] ,
382- span : Span ,
383375 sig_span : Span ,
384376 new_decl_span : Span ,
385- idents : Vec < Ident > ,
377+ idents : & [ Ident ] ,
386378 errored : bool ,
387- ) -> P < ast:: Block > {
379+ ) -> ( P < ast:: Block > , P < ast :: Expr > , P < ast :: Expr > , P < ast :: Expr > ) {
388380 let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
389381 let noop = ast:: InlineAsm {
390382 asm_macro : ast:: AsmMacro :: Asm ,
@@ -433,6 +425,51 @@ mod llvm_enzyme {
433425 }
434426 body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ;
435427
428+ ( body, primal_call, black_box_primal_call, blackbox_call_expr)
429+ }
430+
431+ /// We only want this function to type-check, since we will replace the body
432+ /// later on llvm level. Using `loop {}` does not cover all return types anymore,
433+ /// so instead we manually build something that should pass the type checker.
434+ /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
435+ /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
436+ /// bug would ever try to accidentially differentiate this placeholder function body.
437+ /// Finally, we also add back_box usages of all input arguments, to prevent rustc
438+ /// from optimizing any arguments away.
439+ fn gen_enzyme_body (
440+ ecx : & ExtCtxt < ' _ > ,
441+ x : & AutoDiffAttrs ,
442+ n_active : u32 ,
443+ sig : & ast:: FnSig ,
444+ d_sig : & ast:: FnSig ,
445+ primal : Ident ,
446+ new_names : & [ String ] ,
447+ span : Span ,
448+ sig_span : Span ,
449+ idents : Vec < Ident > ,
450+ errored : bool ,
451+ ) -> P < ast:: Block > {
452+ let new_decl_span = d_sig. span ;
453+
454+ // Just adding some default inline-asm and black_box usages to prevent early inlining
455+ // and optimizations which alter the function signature.
456+ //
457+ // The bb_primal_call is the black_box call of the primal function. We keep it around,
458+ // since it has the convenient property of returning the type of the primal function,
459+ // Remember, we only care to match types here.
460+ // No matter which return we pick, we always wrap it into a std::hint::black_box call,
461+ // to prevent rustc from propagating it into the caller.
462+ let ( mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper (
463+ ecx,
464+ span,
465+ primal,
466+ new_names,
467+ sig_span,
468+ new_decl_span,
469+ & idents,
470+ errored,
471+ ) ;
472+
436473 if !has_ret ( & d_sig. decl . output ) {
437474 // there is no return type that we have to match, () works fine.
438475 return body;
@@ -444,7 +481,7 @@ mod llvm_enzyme {
444481
445482 if primal_ret && n_active == 0 && x. mode . is_rev ( ) {
446483 // We only have the primal ret.
447- body. stmts . push ( ecx. stmt_expr ( black_box_primal_call ) ) ;
484+ body. stmts . push ( ecx. stmt_expr ( bb_primal_call ) ) ;
448485 return body;
449486 }
450487
@@ -536,11 +573,11 @@ mod llvm_enzyme {
536573 return body;
537574 }
538575 [ arg] => {
539- ret = ecx. expr_call ( new_decl_span, blackbox_call_expr , thin_vec ! [ arg. clone( ) ] ) ;
576+ ret = ecx. expr_call ( new_decl_span, bb_call_expr , thin_vec ! [ arg. clone( ) ] ) ;
540577 }
541578 args => {
542579 let ret_tuple: P < ast:: Expr > = ecx. expr_tuple ( span, args. into ( ) ) ;
543- ret = ecx. expr_call ( new_decl_span, blackbox_call_expr , thin_vec ! [ ret_tuple] ) ;
580+ ret = ecx. expr_call ( new_decl_span, bb_call_expr , thin_vec ! [ ret_tuple] ) ;
544581 }
545582 }
546583 assert ! ( has_ret( & d_sig. decl. output) ) ;
@@ -553,7 +590,7 @@ mod llvm_enzyme {
553590 ecx : & ExtCtxt < ' _ > ,
554591 span : Span ,
555592 primal : Ident ,
556- idents : Vec < Ident > ,
593+ idents : & [ Ident ] ,
557594 ) -> P < ast:: Expr > {
558595 let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
559596 if has_self {
0 commit comments