@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
44use std:: sync:: Arc ;
55use std:: { fs, slice, str} ;
66
7- use libc:: { c_char, c_int, c_void, size_t} ;
7+ use libc:: { c_char, c_int, c_uint , c_void, size_t} ;
88use llvm:: {
99 LLVMRustLLVMHasZlibCompressionForDebugSymbols , LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
1010} ;
11+ use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
1112use rustc_codegen_ssa:: back:: link:: ensure_removed;
1213use rustc_codegen_ssa:: back:: write:: {
1314 BitcodeSection , CodegenContext , EmitObj , ModuleConfig , TargetMachineFactoryConfig ,
@@ -27,7 +28,7 @@ use rustc_session::config::{
2728use rustc_span:: InnerSpan ;
2829use rustc_span:: symbol:: sym;
2930use rustc_target:: spec:: { CodeModel , RelocModel , SanitizerSet , SplitDebuginfo , TlsModel } ;
30- use tracing:: debug;
31+ use tracing:: { debug, trace } ;
3132
3233use crate :: back:: lto:: ThinBuffer ;
3334use crate :: back:: owned_target_machine:: OwnedTargetMachine ;
@@ -39,7 +40,13 @@ use crate::errors::{
3940 WithLlvmError , WriteBytecode ,
4041} ;
4142use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
42- use crate :: llvm:: { self , DiagnosticInfo , PassManager } ;
43+ use crate :: llvm:: {
44+ self , AttributeKind , DiagnosticInfo , LLVMCreateStringAttribute , LLVMGetFirstFunction ,
45+ LLVMGetNextFunction , LLVMGetStringAttributeAtIndex , LLVMIsEnumAttribute , LLVMIsStringAttribute ,
46+ LLVMRemoveStringAttributeAtIndex , LLVMRustAddEnumAttributeAtIndex ,
47+ LLVMRustAddFunctionAttributes , LLVMRustGetEnumAttributeAtIndex ,
48+ LLVMRustRemoveEnumAttributeAtIndex , PassManager ,
49+ } ;
4350use crate :: type_:: Type ;
4451use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
4552
@@ -515,9 +522,34 @@ pub(crate) unsafe fn llvm_optimize(
515522 config : & ModuleConfig ,
516523 opt_level : config:: OptLevel ,
517524 opt_stage : llvm:: OptStage ,
525+ skip_size_increasing_opts : bool ,
518526) -> Result < ( ) , FatalError > {
519- let unroll_loops =
520- opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
527+ // Enzyme:
528+ // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
529+ // source code. However, benchmarks show that optimizations increasing the code size
530+ // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
531+ // and finally re-optimize the module, now with all optimizations available.
532+ // TODO: In a future update we could figure out how to only optimize functions getting
533+ // differentiated.
534+
535+ let unroll_loops;
536+ let vectorize_slp;
537+ let vectorize_loop;
538+
539+ if skip_size_increasing_opts {
540+ unroll_loops = false ;
541+ vectorize_slp = false ;
542+ vectorize_loop = false ;
543+ } else {
544+ unroll_loops =
545+ opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
546+ vectorize_slp = config. vectorize_slp ;
547+ vectorize_loop = config. vectorize_loop ;
548+ }
549+ trace ! (
550+ "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}" ,
551+ unroll_loops, vectorize_slp, vectorize_loop
552+ ) ;
521553 let using_thin_buffers = opt_stage == llvm:: OptStage :: PreLinkThinLTO || config. bitcode_needed ( ) ;
522554 let pgo_gen_path = get_pgo_gen_path ( config) ;
523555 let pgo_use_path = get_pgo_use_path ( config) ;
@@ -581,8 +613,8 @@ pub(crate) unsafe fn llvm_optimize(
581613 using_thin_buffers,
582614 config. merge_functions ,
583615 unroll_loops,
584- config . vectorize_slp ,
585- config . vectorize_loop ,
616+ vectorize_slp,
617+ vectorize_loop,
586618 config. no_builtins ,
587619 config. emit_lifetime_markers ,
588620 sanitizer_options. as_ref ( ) ,
@@ -605,6 +637,113 @@ pub(crate) unsafe fn llvm_optimize(
605637 result. into_result ( ) . map_err ( |( ) | llvm_err ( dcx, LlvmError :: RunLlvmPasses ) )
606638}
607639
640+ pub ( crate ) fn differentiate (
641+ module : & ModuleCodegen < ModuleLlvm > ,
642+ cgcx : & CodegenContext < LlvmCodegenBackend > ,
643+ diff_items : Vec < AutoDiffItem > ,
644+ config : & ModuleConfig ,
645+ ) -> Result < ( ) , FatalError > {
646+ for item in & diff_items {
647+ trace ! ( "{}" , item) ;
648+ }
649+
650+ let llmod = module. module_llvm . llmod ( ) ;
651+ let llcx = & module. module_llvm . llcx ;
652+ let diag_handler = cgcx. create_dcx ( ) ;
653+
654+ // Before dumping the module, we want all the tt to become part of the module.
655+ for item in diff_items. iter ( ) {
656+ let name = CString :: new ( item. source . clone ( ) ) . unwrap ( ) ;
657+ let fn_def: Option < & llvm:: Value > =
658+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, name. as_ptr ( ) ) } ;
659+ let fn_def = match fn_def {
660+ Some ( x) => x,
661+ None => {
662+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
663+ src : item. source . clone ( ) ,
664+ target : item. target . clone ( ) ,
665+ error : "could not find source function" . to_owned ( ) ,
666+ } ) ) ;
667+ }
668+ } ;
669+ let tgt_name = CString :: new ( item. target . clone ( ) ) . unwrap ( ) ;
670+ dbg ! ( "Target name: {:?}" , & tgt_name) ;
671+ let fn_target: Option < & llvm:: Value > =
672+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, tgt_name. as_ptr ( ) ) } ;
673+ let fn_target = match fn_target {
674+ Some ( x) => x,
675+ None => {
676+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
677+ src : item. source . clone ( ) ,
678+ target : item. target . clone ( ) ,
679+ error : "could not find target function" . to_owned ( ) ,
680+ } ) ) ;
681+ }
682+ } ;
683+
684+ crate :: builder:: add_opt_dbg_helper2 ( llmod, llcx, fn_def, fn_target, item. attrs . clone ( ) ) ;
685+ }
686+
687+ // We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
688+ // which Enzyme doesn't understand.
689+ unsafe {
690+ let mut f = LLVMGetFirstFunction ( llmod) ;
691+ loop {
692+ if let Some ( lf) = f {
693+ f = LLVMGetNextFunction ( lf) ;
694+ let myhwattr = "enzyme_hw" ;
695+ let attr = LLVMGetStringAttributeAtIndex (
696+ lf,
697+ c_uint:: MAX ,
698+ myhwattr. as_ptr ( ) as * const c_char ,
699+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
700+ ) ;
701+ if LLVMIsStringAttribute ( attr) {
702+ LLVMRemoveStringAttributeAtIndex (
703+ lf,
704+ c_uint:: MAX ,
705+ myhwattr. as_ptr ( ) as * const c_char ,
706+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
707+ ) ;
708+ } else {
709+ LLVMRustRemoveEnumAttributeAtIndex (
710+ lf,
711+ c_uint:: MAX ,
712+ AttributeKind :: SanitizeHWAddress ,
713+ ) ;
714+ }
715+ } else {
716+ break ;
717+ }
718+ }
719+ }
720+
721+ if let Some ( opt_level) = config. opt_level {
722+ let opt_stage = match cgcx. lto {
723+ Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
724+ Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
725+ _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
726+ _ => llvm:: OptStage :: PreLinkNoLTO ,
727+ } ;
728+ let skip_size_increasing_opts = false ;
729+ dbg ! ( "Running Module Optimization after differentiation" ) ;
730+ unsafe {
731+ llvm_optimize (
732+ cgcx,
733+ diag_handler. handle ( ) ,
734+ module,
735+ config,
736+ opt_level,
737+ opt_stage,
738+ skip_size_increasing_opts,
739+ ) ?
740+ } ;
741+ }
742+ dbg ! ( "Done with differentiate()" ) ;
743+
744+ Ok ( ( ) )
745+ }
746+
608747// Unsafe due to LLVM calls.
609748pub ( crate ) unsafe fn optimize (
610749 cgcx : & CodegenContext < LlvmCodegenBackend > ,
@@ -627,14 +766,68 @@ pub(crate) unsafe fn optimize(
627766 unsafe { llvm:: LLVMWriteBitcodeToFile ( llmod, out. as_ptr ( ) ) } ;
628767 }
629768
769+ // This code enables Enzyme to differentiate code containing Rust enums.
770+ // By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing
771+ // away the enums and allows Enzyme to understand why a value can be of different types in
772+ // different code sections. We remove this attribute after Enzyme is done, to not affect the
773+ // rest of the compilation.
774+ #[ cfg( llvm_enzyme) ]
775+ unsafe {
776+ let mut f = LLVMGetFirstFunction ( llmod) ;
777+ loop {
778+ if let Some ( lf) = f {
779+ f = LLVMGetNextFunction ( lf) ;
780+ let myhwattr = "enzyme_hw" ;
781+ let myhwv = "" ;
782+ let prevattr = LLVMRustGetEnumAttributeAtIndex (
783+ lf,
784+ c_uint:: MAX ,
785+ AttributeKind :: SanitizeHWAddress ,
786+ ) ;
787+ if LLVMIsEnumAttribute ( prevattr) {
788+ let attr = LLVMCreateStringAttribute (
789+ llcx,
790+ myhwattr. as_ptr ( ) as * const c_char ,
791+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
792+ myhwv. as_ptr ( ) as * const c_char ,
793+ myhwv. as_bytes ( ) . len ( ) as c_uint ,
794+ ) ;
795+ LLVMRustAddFunctionAttributes ( lf, c_uint:: MAX , & attr, 1 ) ;
796+ } else {
797+ LLVMRustAddEnumAttributeAtIndex (
798+ llcx,
799+ lf,
800+ c_uint:: MAX ,
801+ AttributeKind :: SanitizeHWAddress ,
802+ ) ;
803+ }
804+ } else {
805+ break ;
806+ }
807+ }
808+ }
809+
630810 if let Some ( opt_level) = config. opt_level {
631811 let opt_stage = match cgcx. lto {
632812 Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
633813 Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
634814 _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
635815 _ => llvm:: OptStage :: PreLinkNoLTO ,
636816 } ;
637- return unsafe { llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage) } ;
817+
818+ // If we know that we will later run AD, then we disable vectorization and loop unrolling
819+ let skip_size_increasing_opts = cfg ! ( llvm_enzyme) ;
820+ return unsafe {
821+ llvm_optimize (
822+ cgcx,
823+ dcx,
824+ module,
825+ config,
826+ opt_level,
827+ opt_stage,
828+ skip_size_increasing_opts,
829+ )
830+ } ;
638831 }
639832 Ok ( ( ) )
640833}
0 commit comments