@@ -584,12 +584,10 @@ fn thin_lto(
584584 }
585585}
586586
587- fn enable_autodiff_settings ( ad : & [ config:: AutoDiff ] , module : & mut ModuleCodegen < ModuleLlvm > ) {
587+ fn enable_autodiff_settings ( ad : & [ config:: AutoDiff ] ) {
588588 for & val in ad {
589+ // We intentionally don't use a wildcard, to not forget handling anything new.
589590 match val {
590- config:: AutoDiff :: PrintModBefore => {
591- unsafe { llvm:: LLVMDumpModule ( module. module_llvm . llmod ( ) ) } ;
592- }
593591 config:: AutoDiff :: PrintPerf => {
594592 llvm:: set_print_perf ( true ) ;
595593 }
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
603601 llvm:: set_inline ( true ) ;
604602 }
605603 config:: AutoDiff :: LooseTypes => {
606- llvm:: set_loose_types ( false ) ;
604+ llvm:: set_loose_types ( true ) ;
607605 }
608606 config:: AutoDiff :: PrintSteps => {
609607 llvm:: set_print ( true ) ;
610608 }
611- // We handle this below
609+ // We handle this in the PassWrapper.cpp
610+ config:: AutoDiff :: PrintPasses => { }
611+ // We handle this in the PassWrapper.cpp
612+ config:: AutoDiff :: PrintModBefore => { }
613+ // We handle this in the PassWrapper.cpp
612614 config:: AutoDiff :: PrintModAfter => { }
613- // We handle this below
615+ // We handle this in the PassWrapper.cpp
614616 config:: AutoDiff :: PrintModFinal => { }
615617 // This is required and already checked
616618 config:: AutoDiff :: Enable => { }
619+ // We handle this below
620+ config:: AutoDiff :: NoPostopt => { }
617621 }
618622 }
619623 // This helps with handling enums for now.
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
647651 // We then run the llvm_optimize function a second time, to optimize the code which we generated
648652 // in the enzyme differentiation pass.
649653 let enable_ad = config. autodiff . contains ( & config:: AutoDiff :: Enable ) ;
650- let stage =
651- if enable_ad { write:: AutodiffStage :: DuringAD } else { write:: AutodiffStage :: PostAD } ;
654+ let stage = if thin {
655+ write:: AutodiffStage :: PreAD
656+ } else {
657+ if enable_ad { write:: AutodiffStage :: DuringAD } else { write:: AutodiffStage :: PostAD }
658+ } ;
652659
653660 if enable_ad {
654- enable_autodiff_settings ( & config. autodiff , module ) ;
661+ enable_autodiff_settings ( & config. autodiff ) ;
655662 }
656663
657664 unsafe {
658665 write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
659666 }
660667
661- if cfg ! ( llvm_enzyme) && enable_ad {
662- // This is the post-autodiff IR, mainly used for testing and educational purposes.
663- if config. autodiff . contains ( & config:: AutoDiff :: PrintModAfter ) {
664- unsafe { llvm:: LLVMDumpModule ( module. module_llvm . llmod ( ) ) } ;
665- }
666-
668+ if cfg ! ( llvm_enzyme) && enable_ad && !thin {
667669 let opt_stage = llvm:: OptStage :: FatLTO ;
668670 let stage = write:: AutodiffStage :: PostAD ;
669- unsafe {
670- write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
671+ if !config. autodiff . contains ( & config:: AutoDiff :: NoPostopt ) {
672+ unsafe {
673+ write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
674+ }
671675 }
672676
673677 // This is the final IR, so people should be able to inspect the optimized autodiff output,
0 commit comments