Skip to content

Commit 5938ed9

Browse files
authored
Add validation to ensure collectives are paired with a wait op to collective parity json (#143)
This PR introduces a new field to the existing collectives_parity.json on the tlparse page. It ensures that in inductor_output_code, every collective op issued has its respective wait op. If not, the missing_wait field will update accordingly. Updated tests. <img width="1237" height="372" alt="image" src="https://github.com/user-attachments/assets/08d3c1b5-13a0-41d5-897d-8361d204a278" />
1 parent fc687f8 commit 5938ed9

File tree

4 files changed

+30
-13
lines changed

4 files changed

+30
-13
lines changed

src/parsers.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ pub fn check_collectives_parity(out_path: &PathBuf, rank_nums: &[u32]) -> anyhow
753753
let call_re = Regex::new(
754754
r"torch\s*\.\s*ops\s*\.\s*_?c10d_functional\s*\.\s*([A-Za-z0-9_]+)\s*\.\s*default\s*\(",
755755
)?;
756-
let comment_re = Regex::new(r"(?m)^\s*#.*$|\s#[^0-9a-fA-F].*$|//.*$|(?s)/\*.*?\*/")?;
756+
let comment_re = Regex::new(r"(?m)#.*$|//.*$|(?s)/\*.*?\*/")?;
757757
let html_tag_re = Regex::new(r"(?s)<[^>]*>")?;
758758

759759
for &rank in rank_nums {
@@ -787,7 +787,7 @@ pub fn check_collectives_parity(out_path: &PathBuf, rank_nums: &[u32]) -> anyhow
787787
.unwrap_or_default();
788788

789789
let mut report = crate::types::CollectivesParityReport {
790-
description: "Difference of # of collectives in scheduler and inductor output code"
790+
description: "Difference of # of collectives in scheduler and inductor output code and missing wait collectives"
791791
.to_string(),
792792
graphs: Vec::new(),
793793
};
@@ -848,11 +848,17 @@ pub fn check_collectives_parity(out_path: &PathBuf, rank_nums: &[u32]) -> anyhow
848848
.replace_all(&html_tag_re.replace_all(&fs::read_to_string(code)?, ""), "")
849849
.into_owned();
850850
let mut code_counts: HashMap<&str, usize> = HashMap::new();
851+
let mut wait_count = 0usize;
851852
for cap in call_re.captures_iter(&code_clean) {
852-
if let Some(normalized) = normalize_op(cap.get(1).unwrap().as_str()) {
853+
let op = cap.get(1).unwrap().as_str();
854+
if op == "wait_tensor" {
855+
wait_count += 1;
856+
} else if let Some(normalized) = normalize_op(op) {
853857
*code_counts.entry(normalized).or_insert(0) += 1;
854858
}
855859
}
860+
let collective_total: usize = code_counts.values().sum();
861+
let missing_waits = collective_total.saturating_sub(wait_count);
856862

857863
// Compute offset over union of all detected ops
858864
let mut all_ops: std::collections::HashSet<&str> =
@@ -869,7 +875,7 @@ pub fn check_collectives_parity(out_path: &PathBuf, rank_nums: &[u32]) -> anyhow
869875
})
870876
.sum();
871877

872-
if offset > 0 {
878+
if offset > 0 || missing_waits > 0 {
873879
let graph = compile_dir
874880
.file_name()
875881
.and_then(|n| n.to_str())
@@ -883,6 +889,7 @@ pub fn check_collectives_parity(out_path: &PathBuf, rank_nums: &[u32]) -> anyhow
883889
graph,
884890
compile_id,
885891
offset,
892+
missing_waits,
886893
});
887894
}
888895
}

src/types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ pub struct GraphCollectivesParity {
5858
pub graph: String,
5959
pub compile_id: String,
6060
pub offset: usize,
61+
#[serde(default)]
62+
pub missing_waits: usize,
6163
}
6264

6365
#[derive(Debug, Serialize, Deserialize)]

tests/inputs/collectives_parity/dedicated_log_torch_trace_rank_0.log

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5217,7 +5217,7 @@ V0804 12:34:16.809000 1142857 torch/_inductor/graph.py:2390] {"inductor_output_c
52175217
# Topologically Sorted Source Nodes: [h_1, h_2, all_reduce_default], Original ATen: [aten.gelu, aten.native_layer_norm, _c10d_functional.all_reduce]
52185218
torch.ops._c10d_functional.all_reduce_.default(buf4, 'sum', '0')
52195219
# Topologically Sorted Source Nodes: [h_3], Original ATen: [_c10d_functional.wait_tensor]
5220-
torch.ops._c10d_functional.wait_tensor.default(buf4)
5220+
#torch.ops._c10d_functional.wait_tensor.default(buf4)
52215221
buf9 = empty_strided_cuda((1024, 1024), (1024, 1), torch.float16)
52225222
# Topologically Sorted Source Nodes: [h2], Original ATen: [aten.mm]
52235223
extern_kernels.mm(buf4, reinterpret_tensor(arg4_1, (1024, 1024), (1, 1024), 0), out=buf9)
@@ -5227,7 +5227,7 @@ V0804 12:34:16.809000 1142857 torch/_inductor/graph.py:2390] {"inductor_output_c
52275227
stream0 = get_raw_stream(0)
52285228
triton_poi_fused_all_gather_into_tensor_relu_1.run(buf10, 1048576, stream=stream0)
52295229
# Topologically Sorted Source Nodes: [h2_1, all_gather_into_tensor_default], Original ATen: [aten.relu, _c10d_functional.all_gather_into_tensor]
5230-
#buf11 = torch.ops._c10d_functional.all_gather_into_tensor.default(buf10, 2, '0')
5230+
buf11 = torch.ops._c10d_functional.all_gather_into_tensor.default(buf10, 2, '0')
52315231
assert_size_stride(buf11, (2048, 1024), (1024, 1), 'torch.ops._c10d_functional.all_gather_into_tensor.default')
52325232
assert_alignment(buf11, 16, 'torch.ops._c10d_functional.all_gather_into_tensor.default')
52335233
del buf4
@@ -5236,10 +5236,10 @@ V0804 12:34:16.809000 1142857 torch/_inductor/graph.py:2390] {"inductor_output_c
52365236
assert_size_stride(buf12, (1024, 1024), (1024, 1), 'torch.ops._c10d_functional.reduce_scatter_tensor.default')
52375237
assert_alignment(buf12, 16, 'torch.ops._c10d_functional.reduce_scatter_tensor.default')
52385238
# Topologically Sorted Source Nodes: [gathered], Original ATen: [_c10d_functional.wait_tensor]
5239-
torch.ops._c10d_functional.wait_tensor.default(buf11)
5239+
#torch.ops._c10d_functional.wait_tensor.default(buf11)
52405240
del buf10
52415241
# Topologically Sorted Source Nodes: [rs], Original ATen: [_c10d_functional.wait_tensor]
5242-
torch.ops._c10d_functional.wait_tensor.default(buf12)
5242+
#torch.ops._c10d_functional.wait_tensor.default(buf12)
52435243
del arg5_1
52445244
buf17 = empty_strided_cuda((2048, 1024), (1024, 1), torch.float16)
52455245
# Topologically Sorted Source Nodes: [g, rs_expanded, out], Original ATen: [aten.mul, aten.repeat, aten.add]

tests/integration_test.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2471,12 +2471,20 @@ fn test_collectives_parity_detects_mismatch() -> Result<(), Box<dyn std::error::
24712471
);
24722472
let rank_0_report: CollectivesParityReport =
24732473
serde_json::from_str(&fs::read_to_string(&rank_0_report_path)?)?;
2474-
// Expect single mismatch entry for graph -_0_1_0 with compile_id [0/1] and offset 1
2475-
assert_eq!(rank_0_report.graphs.len(), 1);
2476-
let g = &rank_0_report.graphs[0];
2477-
assert_eq!(g.graph, "-_0_1_0");
2474+
// Expect mismatches: graph -_0_1_0 has offset 1, graph -_0_0_0 missing wait
2475+
assert_eq!(rank_0_report.graphs.len(), 2);
2476+
let mut by_graph = HashMap::new();
2477+
for g in &rank_0_report.graphs {
2478+
by_graph.insert(g.graph.as_str(), g);
2479+
}
2480+
let g = by_graph.get("-_0_1_0").expect("missing -_0_1_0 entry");
24782481
assert_eq!(g.compile_id, "[0/1]");
2479-
assert_eq!(g.offset, 1);
2482+
assert_eq!(g.offset, 0);
2483+
assert_eq!(g.missing_waits, 3);
2484+
let g0 = by_graph.get("-_0_0_0").expect("missing -_0_0_0 entry");
2485+
assert_eq!(g0.compile_id, "[0/0]");
2486+
assert_eq!(g0.offset, 0);
2487+
assert_eq!(g0.missing_waits, 1);
24802488

24812489
Ok(())
24822490
}

0 commit comments

Comments
 (0)