Skip to content

Commit fc687f8

Browse files
authored
Add Collective Operations Parity Check Between Scheduler and Inductor Output Code (#136)
Collective Operations Parity Analysis - Compares collective operations between two sources: - Scheduler artifacts: Operations planned in inductor_collective_schedule_*.json files - Generated code: Actual collective calls present in inductor_output_code_*.py files - Counts occurrences of 6 collective operation types: all_reduce, reduce_scatter, all_gather, broadcast, reduce, all_to_all - Calculates the absolute difference (offset) between planned vs actual operations - Generates a collectives_parity.json report for each rank <img width="1093" height="754" alt="Screenshot 2025-08-22 at 9 49 51 AM" src="https://github.com/user-attachments/assets/13fe8bbe-5110-4060-b29f-f90d4f52b26a" /> - Ex: <img width="906" height="504" alt="image" src="https://github.com/user-attachments/assets/7f299a81-0da0-45cc-b80e-84f78cfcb7c4" />
1 parent f0c41db commit fc687f8

File tree

7 files changed

+15374
-3
lines changed

7 files changed

+15374
-3
lines changed

src/cli.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,8 @@ fn handle_all_ranks(
459459
println!("Collective schedules: {}", schedules_path.display());
460460
}
461461

462+
tlparse::parsers::check_collectives_parity(&out_path, &rank_nums)?;
463+
462464
// Process tensor meta fingerprints from all ranks
463465
let tensor_meta = tlparse::parsers::read_tensor_meta_fingerprints(&out_path, &rank_nums)?;
464466
let mut tensor_meta_groups: FxHashMap<String, Vec<u32>> = FxHashMap::default();

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ mod templates;
2525
mod types;
2626

2727
pub use types::{
28-
ArtifactFlags, Diagnostics, DivergenceFlags, DivergenceGroup, GraphAnalysis, GraphRuntime,
29-
RankMetaData, RuntimeAnalysis, RuntimeRankDetail,
28+
ArtifactFlags, CollectivesParityReport, Diagnostics, DivergenceFlags, DivergenceGroup,
29+
GraphAnalysis, GraphCollectivesParity, GraphRuntime, RankMetaData, RuntimeAnalysis,
30+
RuntimeRankDetail,
3031
};
3132

3233
pub use execution_order::{

src/parsers.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,157 @@ pub fn read_collective_schedules(
745745
)
746746
}
747747

748+
pub fn check_collectives_parity(out_path: &PathBuf, rank_nums: &[u32]) -> anyhow::Result<()> {
749+
use regex::Regex;
750+
use std::{collections::HashMap, fs};
751+
752+
// Match c10d functional calls: torch.ops._c10d_functional.<op>.default(
753+
let call_re = Regex::new(
754+
r"torch\s*\.\s*ops\s*\.\s*_?c10d_functional\s*\.\s*([A-Za-z0-9_]+)\s*\.\s*default\s*\(",
755+
)?;
756+
let comment_re = Regex::new(r"(?m)^\s*#.*$|\s#[^0-9a-fA-F].*$|//.*$|(?s)/\*.*?\*/")?;
757+
let html_tag_re = Regex::new(r"(?s)<[^>]*>")?;
758+
759+
for &rank in rank_nums {
760+
let rank_dir = out_path.join(format!("rank_{rank}"));
761+
if !rank_dir.exists() {
762+
continue;
763+
}
764+
765+
// Map compile directory (graph folder) name prefix -> compile ID
766+
let dir_to_compile_id: HashMap<String, String> =
767+
fs::read_to_string(rank_dir.join("compile_directory.json"))
768+
.ok()
769+
.and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
770+
.and_then(|v| {
771+
v.as_object().map(|obj| {
772+
obj.iter().fold(HashMap::new(), |mut m, (cid, entry)| {
773+
if let Some(arts) = entry.get("artifacts").and_then(|x| x.as_array()) {
774+
for a in arts {
775+
if let Some(url) = a.get("url").and_then(|x| x.as_str()) {
776+
if let Some((prefix, _)) = url.split_once('/') {
777+
m.entry(prefix.to_string())
778+
.or_insert_with(|| cid.to_string());
779+
}
780+
}
781+
}
782+
}
783+
m
784+
})
785+
})
786+
})
787+
.unwrap_or_default();
788+
789+
let mut report = crate::types::CollectivesParityReport {
790+
description: "Difference of # of collectives in scheduler and inductor output code"
791+
.to_string(),
792+
graphs: Vec::new(),
793+
};
794+
795+
for compile_dir in fs::read_dir(&rank_dir)?
796+
.flatten()
797+
.map(|e| e.path())
798+
.filter(|p| p.is_dir())
799+
{
800+
let (mut schedule_path, mut code_path) = (None, None);
801+
for p in fs::read_dir(&compile_dir)?.flatten().map(|e| e.path()) {
802+
let stem = p.file_stem().and_then(|s| s.to_str()).unwrap_or("");
803+
if p.extension() == Some(OsStr::new("json"))
804+
&& stem.starts_with("inductor_collective_schedule")
805+
{
806+
schedule_path = Some(p);
807+
} else if stem.starts_with("inductor_output_code") && code_path.is_none() {
808+
code_path = Some(p);
809+
}
810+
}
811+
812+
let (Some(schedule), Some(code)) = (schedule_path, code_path) else {
813+
continue;
814+
};
815+
816+
let raw_ops: Vec<String> =
817+
serde_json::from_str(&fs::read_to_string(schedule)?).unwrap_or_default();
818+
// Extract and normalize op names from schedule
819+
let normalize_op = |op: &str| -> Option<&'static str> {
820+
let op = op.trim_end_matches('_');
821+
[
822+
"all_reduce",
823+
"reduce_scatter",
824+
"all_gather",
825+
"broadcast",
826+
"all_to_all",
827+
]
828+
.iter()
829+
.find(|&&name| op.contains(name))
830+
.copied()
831+
.or_else(|| {
832+
(op.contains("reduce")
833+
&& !op.contains("all_reduce")
834+
&& !op.contains("reduce_scatter"))
835+
.then_some("reduce")
836+
})
837+
};
838+
839+
let mut schedule_counts: HashMap<&str, usize> = HashMap::new();
840+
for op in &raw_ops {
841+
if let Some(normalized) = normalize_op(op) {
842+
*schedule_counts.entry(normalized).or_insert(0) += 1;
843+
}
844+
}
845+
846+
// Code counts: strip tags and comments, then count calls
847+
let code_clean = comment_re
848+
.replace_all(&html_tag_re.replace_all(&fs::read_to_string(code)?, ""), "")
849+
.into_owned();
850+
let mut code_counts: HashMap<&str, usize> = HashMap::new();
851+
for cap in call_re.captures_iter(&code_clean) {
852+
if let Some(normalized) = normalize_op(cap.get(1).unwrap().as_str()) {
853+
*code_counts.entry(normalized).or_insert(0) += 1;
854+
}
855+
}
856+
857+
// Compute offset over union of all detected ops
858+
let mut all_ops: std::collections::HashSet<&str> =
859+
schedule_counts.keys().copied().collect();
860+
all_ops.extend(code_counts.keys().copied());
861+
let offset: usize = all_ops
862+
.iter()
863+
.map(|&n| {
864+
schedule_counts
865+
.get(n)
866+
.copied()
867+
.unwrap_or(0)
868+
.abs_diff(code_counts.get(n).copied().unwrap_or(0))
869+
})
870+
.sum();
871+
872+
if offset > 0 {
873+
let graph = compile_dir
874+
.file_name()
875+
.and_then(|n| n.to_str())
876+
.unwrap_or("unknown")
877+
.to_string();
878+
let compile_id = dir_to_compile_id
879+
.get(&graph)
880+
.cloned()
881+
.unwrap_or_else(|| "unknown".to_string());
882+
report.graphs.push(crate::types::GraphCollectivesParity {
883+
graph,
884+
compile_id,
885+
offset,
886+
});
887+
}
888+
}
889+
890+
fs::write(
891+
rank_dir.join("collectives_parity.json"),
892+
serde_json::to_string_pretty(&report)?,
893+
)?;
894+
}
895+
896+
Ok(())
897+
}
898+
748899
/// Parses a prefixed JSON file from each multi-rank output directory.
749900
/// It finds the first matching file, calls `parse_fn` on its contents,
750901
/// and collects the `Some(T)` results into a vector.

src/templates.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ PT2 generates <a href='chromium_events.json'>Chromium Trace Events</a> in JSON o
185185
You can download and view them in a tool like <a href='https://ui.perfetto.dev/'>Perfetto</a>.
186186
{{ endif }}
187187
<p>
188+
<a href="collectives_parity.json">Collectives Parity report</a> comparing scheduler and Inductor output code collective operations.
189+
</p>
190+
<p>
188191
Build products below:
189192
</p>
190193
<ul>

src/types.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ pub struct TensorMetaFingerprint {
5353
pub fingerprint: String,
5454
}
5555

56+
#[derive(Debug, Serialize, Deserialize)]
57+
pub struct GraphCollectivesParity {
58+
pub graph: String,
59+
pub compile_id: String,
60+
pub offset: usize,
61+
}
62+
63+
#[derive(Debug, Serialize, Deserialize)]
64+
pub struct CollectivesParityReport {
65+
pub description: String,
66+
pub graphs: Vec<GraphCollectivesParity>,
67+
}
5668
/// Estimated runtime entry for a single op within a graph.
5769
#[derive(Debug, Serialize, Deserialize)]
5870
pub struct OpRuntime {

0 commit comments

Comments
 (0)