@@ -745,6 +745,157 @@ pub fn read_collective_schedules(
745
745
)
746
746
}
747
747
748
+ pub fn check_collectives_parity ( out_path : & PathBuf , rank_nums : & [ u32 ] ) -> anyhow:: Result < ( ) > {
749
+ use regex:: Regex ;
750
+ use std:: { collections:: HashMap , fs} ;
751
+
752
+ // Match c10d functional calls: torch.ops._c10d_functional.<op>.default(
753
+ let call_re = Regex :: new (
754
+ r"torch\s*\.\s*ops\s*\.\s*_?c10d_functional\s*\.\s*([A-Za-z0-9_]+)\s*\.\s*default\s*\(" ,
755
+ ) ?;
756
+ let comment_re = Regex :: new ( r"(?m)^\s*#.*$|\s#[^0-9a-fA-F].*$|//.*$|(?s)/\*.*?\*/" ) ?;
757
+ let html_tag_re = Regex :: new ( r"(?s)<[^>]*>" ) ?;
758
+
759
+ for & rank in rank_nums {
760
+ let rank_dir = out_path. join ( format ! ( "rank_{rank}" ) ) ;
761
+ if !rank_dir. exists ( ) {
762
+ continue ;
763
+ }
764
+
765
+ // Map compile directory (graph folder) name prefix -> compile ID
766
+ let dir_to_compile_id: HashMap < String , String > =
767
+ fs:: read_to_string ( rank_dir. join ( "compile_directory.json" ) )
768
+ . ok ( )
769
+ . and_then ( |s| serde_json:: from_str :: < serde_json:: Value > ( & s) . ok ( ) )
770
+ . and_then ( |v| {
771
+ v. as_object ( ) . map ( |obj| {
772
+ obj. iter ( ) . fold ( HashMap :: new ( ) , |mut m, ( cid, entry) | {
773
+ if let Some ( arts) = entry. get ( "artifacts" ) . and_then ( |x| x. as_array ( ) ) {
774
+ for a in arts {
775
+ if let Some ( url) = a. get ( "url" ) . and_then ( |x| x. as_str ( ) ) {
776
+ if let Some ( ( prefix, _) ) = url. split_once ( '/' ) {
777
+ m. entry ( prefix. to_string ( ) )
778
+ . or_insert_with ( || cid. to_string ( ) ) ;
779
+ }
780
+ }
781
+ }
782
+ }
783
+ m
784
+ } )
785
+ } )
786
+ } )
787
+ . unwrap_or_default ( ) ;
788
+
789
+ let mut report = crate :: types:: CollectivesParityReport {
790
+ description : "Difference of # of collectives in scheduler and inductor output code"
791
+ . to_string ( ) ,
792
+ graphs : Vec :: new ( ) ,
793
+ } ;
794
+
795
+ for compile_dir in fs:: read_dir ( & rank_dir) ?
796
+ . flatten ( )
797
+ . map ( |e| e. path ( ) )
798
+ . filter ( |p| p. is_dir ( ) )
799
+ {
800
+ let ( mut schedule_path, mut code_path) = ( None , None ) ;
801
+ for p in fs:: read_dir ( & compile_dir) ?. flatten ( ) . map ( |e| e. path ( ) ) {
802
+ let stem = p. file_stem ( ) . and_then ( |s| s. to_str ( ) ) . unwrap_or ( "" ) ;
803
+ if p. extension ( ) == Some ( OsStr :: new ( "json" ) )
804
+ && stem. starts_with ( "inductor_collective_schedule" )
805
+ {
806
+ schedule_path = Some ( p) ;
807
+ } else if stem. starts_with ( "inductor_output_code" ) && code_path. is_none ( ) {
808
+ code_path = Some ( p) ;
809
+ }
810
+ }
811
+
812
+ let ( Some ( schedule) , Some ( code) ) = ( schedule_path, code_path) else {
813
+ continue ;
814
+ } ;
815
+
816
+ let raw_ops: Vec < String > =
817
+ serde_json:: from_str ( & fs:: read_to_string ( schedule) ?) . unwrap_or_default ( ) ;
818
+ // Extract and normalize op names from schedule
819
+ let normalize_op = |op : & str | -> Option < & ' static str > {
820
+ let op = op. trim_end_matches ( '_' ) ;
821
+ [
822
+ "all_reduce" ,
823
+ "reduce_scatter" ,
824
+ "all_gather" ,
825
+ "broadcast" ,
826
+ "all_to_all" ,
827
+ ]
828
+ . iter ( )
829
+ . find ( |& & name| op. contains ( name) )
830
+ . copied ( )
831
+ . or_else ( || {
832
+ ( op. contains ( "reduce" )
833
+ && !op. contains ( "all_reduce" )
834
+ && !op. contains ( "reduce_scatter" ) )
835
+ . then_some ( "reduce" )
836
+ } )
837
+ } ;
838
+
839
+ let mut schedule_counts: HashMap < & str , usize > = HashMap :: new ( ) ;
840
+ for op in & raw_ops {
841
+ if let Some ( normalized) = normalize_op ( op) {
842
+ * schedule_counts. entry ( normalized) . or_insert ( 0 ) += 1 ;
843
+ }
844
+ }
845
+
846
+ // Code counts: strip tags and comments, then count calls
847
+ let code_clean = comment_re
848
+ . replace_all ( & html_tag_re. replace_all ( & fs:: read_to_string ( code) ?, "" ) , "" )
849
+ . into_owned ( ) ;
850
+ let mut code_counts: HashMap < & str , usize > = HashMap :: new ( ) ;
851
+ for cap in call_re. captures_iter ( & code_clean) {
852
+ if let Some ( normalized) = normalize_op ( cap. get ( 1 ) . unwrap ( ) . as_str ( ) ) {
853
+ * code_counts. entry ( normalized) . or_insert ( 0 ) += 1 ;
854
+ }
855
+ }
856
+
857
+ // Compute offset over union of all detected ops
858
+ let mut all_ops: std:: collections:: HashSet < & str > =
859
+ schedule_counts. keys ( ) . copied ( ) . collect ( ) ;
860
+ all_ops. extend ( code_counts. keys ( ) . copied ( ) ) ;
861
+ let offset: usize = all_ops
862
+ . iter ( )
863
+ . map ( |& n| {
864
+ schedule_counts
865
+ . get ( n)
866
+ . copied ( )
867
+ . unwrap_or ( 0 )
868
+ . abs_diff ( code_counts. get ( n) . copied ( ) . unwrap_or ( 0 ) )
869
+ } )
870
+ . sum ( ) ;
871
+
872
+ if offset > 0 {
873
+ let graph = compile_dir
874
+ . file_name ( )
875
+ . and_then ( |n| n. to_str ( ) )
876
+ . unwrap_or ( "unknown" )
877
+ . to_string ( ) ;
878
+ let compile_id = dir_to_compile_id
879
+ . get ( & graph)
880
+ . cloned ( )
881
+ . unwrap_or_else ( || "unknown" . to_string ( ) ) ;
882
+ report. graphs . push ( crate :: types:: GraphCollectivesParity {
883
+ graph,
884
+ compile_id,
885
+ offset,
886
+ } ) ;
887
+ }
888
+ }
889
+
890
+ fs:: write (
891
+ rank_dir. join ( "collectives_parity.json" ) ,
892
+ serde_json:: to_string_pretty ( & report) ?,
893
+ ) ?;
894
+ }
895
+
896
+ Ok ( ( ) )
897
+ }
898
+
748
899
/// Parses a prefixed JSON file from each multi-rank output directory.
749
900
/// It finds the first matching file, calls `parse_fn` on its contents,
750
901
/// and collects the `Some(T)` results into a vector.
0 commit comments