@@ -21,7 +21,7 @@ mod llvm_enzyme {
2121 MetaItemInner , PatKind , Path , PathSegment , TyKind , Visibility ,
2222 } ;
2323 use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
24- use rustc_span:: { Ident , Span , Symbol , kw , sym} ;
24+ use rustc_span:: { Ident , Span , Symbol , sym} ;
2525 use thin_vec:: { ThinVec , thin_vec} ;
2626 use tracing:: { debug, trace} ;
2727
@@ -183,11 +183,8 @@ mod llvm_enzyme {
183183 }
184184
185185 /// We expand the autodiff macro to generate a new placeholder function which passes
186- /// type-checking and can be called by users. The function body of the placeholder function will
187- /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
188- /// should just prevent early inlining and optimizations which alter the function signature.
189- /// The exact signature of the generated function depends on the configuration provided by the
190- /// user, but here is an example:
186+ /// type-checking and can be called by users. The exact signature of the generated function
187+ /// depends on the configuration provided by the user, but here is an example:
191188 ///
192189 /// ```
193190 /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@@ -203,14 +200,8 @@ mod llvm_enzyme {
203200 /// f32::sin(**x)
204201 /// }
205202 /// #[rustc_autodiff(Reverse, Duplicated, Active)]
206- /// #[inline(never)]
207203 /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
208- /// unsafe {
209- /// asm!("NOP");
210- /// };
211- /// ::core::hint::black_box(sin(x));
212- /// ::core::hint::black_box((dx, dret));
213- /// ::core::hint::black_box(sin(x))
204+ /// std::intrinsics::enzyme_autodiff(sin::<>, cos_box::<>, (x, dx, dret))
214205 /// }
215206 /// ```
216207 /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -330,22 +321,20 @@ mod llvm_enzyme {
330321 }
331322 let span = ecx. with_def_site_ctxt ( expand_span) ;
332323
333- let ( d_sig, idents , errored ) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
324+ let d_sig = gen_enzyme_decl ( ecx, & sig, & x, span) ;
334325
335326 let d_body = gen_enzyme_body (
336327 ecx,
337328 & d_sig,
338329 primal,
339330 span,
340- idents,
341- errored,
342331 first_ident ( & meta_item_vec[ 0 ] ) ,
343332 & generics,
344333 impl_of_trait,
345334 ) ;
346335
347336 // The first element of it is the name of the function to be generated
348- let asdf = Box :: new ( ast:: Fn {
337+ let d_fn = Box :: new ( ast:: Fn {
349338 defaultness : ast:: Defaultness :: Final ,
350339 sig : d_sig,
351340 ident : first_ident ( & meta_item_vec[ 0 ] ) ,
@@ -442,7 +431,7 @@ mod llvm_enzyme {
442431 let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
443432 let d_annotatable = match & item {
444433 Annotatable :: AssocItem ( _, _) => {
445- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf ) ;
434+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( d_fn ) ;
446435 let d_fn = P ( ast:: AssocItem {
447436 attrs : thin_vec ! [ d_attr] ,
448437 id : ast:: DUMMY_NODE_ID ,
@@ -454,13 +443,13 @@ mod llvm_enzyme {
454443 Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
455444 }
456445 Annotatable :: Item ( _) => {
457- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf ) ) ;
446+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( d_fn ) ) ;
458447 d_fn. vis = vis;
459448
460449 Annotatable :: Item ( d_fn)
461450 }
462451 Annotatable :: Stmt ( _) => {
463- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf ) ) ;
452+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( d_fn ) ) ;
464453 d_fn. vis = vis;
465454
466455 Annotatable :: Stmt ( P ( ast:: Stmt {
@@ -525,14 +514,8 @@ mod llvm_enzyme {
525514 . into ( ) ,
526515 ) ;
527516
528- let enzyme_path = ecx. path (
529- span,
530- vec ! [
531- Ident :: from_str( "std" ) ,
532- Ident :: from_str( "intrinsics" ) ,
533- Ident :: with_dummy_span( sym:: enzyme_autodiff) ,
534- ] ,
535- ) ;
517+ let enzyme_path_idents = ecx. std_path ( & [ sym:: intrinsics, sym:: enzyme_autodiff] ) ;
518+ let enzyme_path = ecx. path ( span, enzyme_path_idents) ;
536519 let call_expr = ecx. expr_call (
537520 span,
538521 ecx. expr_path ( enzyme_path) ,
@@ -591,25 +574,6 @@ mod llvm_enzyme {
591574 ecx. expr_path ( path)
592575 }
593576
594- // Will generate a body of the type:
595- // ```
596- // primal(args);
597- // std::intrinsics::enzyme_autodiff(primal, diff, (args))
598- // }
599- // ```
600- fn init_body_helper (
601- ecx : & ExtCtxt < ' _ > ,
602- span : Span ,
603- primal : Ident ,
604- idents : & [ Ident ] ,
605- _errored : bool ,
606- generics : & Generics ,
607- ) -> P < ast:: Block > {
608- let _primal_call = gen_primal_call ( ecx, span, primal, idents, generics) ;
609- let body = ecx. block ( span, ThinVec :: new ( ) ) ;
610- body
611- }
612-
613577 /// We only want this function to type-check, since we will replace the body
614578 /// later on llvm level. Using `loop {}` does not cover all return types anymore,
615579 /// so instead we manually build something that should pass the type checker.
@@ -623,8 +587,6 @@ mod llvm_enzyme {
623587 d_sig : & ast:: FnSig ,
624588 primal : Ident ,
625589 span : Span ,
626- idents : Vec < Ident > ,
627- errored : bool ,
628590 diff_ident : Ident ,
629591 generics : & Generics ,
630592 is_impl : bool ,
@@ -633,87 +595,22 @@ mod llvm_enzyme {
633595
634596 // Add a call to the primal function to prevent it from being inlined
635597 // and call `enzyme_autodiff` intrinsic (this also covers the return type)
636- let mut body = init_body_helper ( ecx, span, primal, & idents, errored, generics) ;
637-
638- body. stmts . push ( call_enzyme_autodiff (
639- ecx,
640- primal,
641- diff_ident,
642- new_decl_span,
643- d_sig,
644- generics,
645- is_impl,
646- ) ) ;
598+ let body = ecx. block (
599+ span,
600+ thin_vec ! [ call_enzyme_autodiff(
601+ ecx,
602+ primal,
603+ diff_ident,
604+ new_decl_span,
605+ d_sig,
606+ generics,
607+ is_impl,
608+ ) ] ,
609+ ) ;
647610
648611 body
649612 }
650613
651- fn gen_primal_call (
652- ecx : & ExtCtxt < ' _ > ,
653- span : Span ,
654- primal : Ident ,
655- idents : & [ Ident ] ,
656- generics : & Generics ,
657- ) -> P < ast:: Expr > {
658- let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
659-
660- if has_self {
661- let args: ThinVec < _ > =
662- idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
663- let self_expr = ecx. expr_self ( span) ;
664- ecx. expr_method_call ( span, self_expr, primal, args)
665- } else {
666- let args: ThinVec < _ > =
667- idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
668- let mut primal_path = ecx. path_ident ( span, primal) ;
669-
670- let is_generic = !generics. params . is_empty ( ) ;
671-
672- match ( is_generic, primal_path. segments . last_mut ( ) ) {
673- ( true , Some ( function_path) ) => {
674- let primal_generic_types = generics
675- . params
676- . iter ( )
677- . filter ( |param| matches ! ( param. kind, ast:: GenericParamKind :: Type { .. } ) ) ;
678-
679- let generated_generic_types = primal_generic_types
680- . map ( |type_param| {
681- let generic_param = TyKind :: Path (
682- None ,
683- ast:: Path {
684- span,
685- segments : thin_vec ! [ ast:: PathSegment {
686- ident: type_param. ident,
687- args: None ,
688- id: ast:: DUMMY_NODE_ID ,
689- } ] ,
690- tokens : None ,
691- } ,
692- ) ;
693-
694- ast:: AngleBracketedArg :: Arg ( ast:: GenericArg :: Type ( P ( ast:: Ty {
695- id : type_param. id ,
696- span,
697- kind : generic_param,
698- tokens : None ,
699- } ) ) )
700- } )
701- . collect ( ) ;
702-
703- function_path. args =
704- Some ( P ( ast:: GenericArgs :: AngleBracketed ( ast:: AngleBracketedArgs {
705- span,
706- args : generated_generic_types,
707- } ) ) ) ;
708- }
709- _ => { }
710- }
711-
712- let primal_call_expr = ecx. expr_path ( primal_path) ;
713- ecx. expr_call ( span, primal_call_expr, args)
714- }
715- }
716-
717614 // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
718615 // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
719616 // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
@@ -730,7 +627,7 @@ mod llvm_enzyme {
730627 sig : & ast:: FnSig ,
731628 x : & AutoDiffAttrs ,
732629 span : Span ,
733- ) -> ( ast:: FnSig , Vec < Ident > , bool ) {
630+ ) -> ast:: FnSig {
734631 let dcx = ecx. sess . dcx ( ) ;
735632 let has_ret = has_ret ( & sig. decl . output ) ;
736633 let sig_args = sig. decl . inputs . len ( ) + if has_ret { 1 } else { 0 } ;
@@ -742,7 +639,7 @@ mod llvm_enzyme {
742639 found : num_activities,
743640 } ) ;
744641 // This is not the right signature, but we can continue parsing.
745- return ( sig. clone ( ) , vec ! [ ] , true ) ;
642+ return sig. clone ( ) ;
746643 }
747644 assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
748645 assert ! ( has_ret == x. has_ret_activity( ) ) ;
@@ -785,7 +682,7 @@ mod llvm_enzyme {
785682
786683 if errors {
787684 // This is not the right signature, but we can continue parsing.
788- return ( sig. clone ( ) , idents , true ) ;
685+ return sig. clone ( ) ;
789686 }
790687
791688 let unsafe_activities = x
@@ -993,7 +890,7 @@ mod llvm_enzyme {
993890 }
994891 let d_sig = FnSig { header : d_header, decl : d_decl, span } ;
995892 trace ! ( "Generated signature: {:?}" , d_sig) ;
996- ( d_sig, idents , false )
893+ d_sig
997894 }
998895}
999896
0 commit comments