Skip to content

Commit c06aa71

Browse files
committed
use structural comparison in topological order for rebuilt computation graph equality tests
1 parent 7ed87b2 commit c06aa71

File tree

3 files changed

+107
-17
lines changed

3 files changed

+107
-17
lines changed

graph_builder.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
"""
2-
Generated by ChatGPT
3-
"""
41
from typing import Dict, List, Tuple
52

63
import torch
@@ -147,9 +144,9 @@ def forward(
147144
self,
148145
loss: torch.Tensor,
149146
prev_loss: torch.Tensor,
150-
named_params: List[Tuple[str, torch.Tensor]],
147+
named_parameters: List[Tuple[str, torch.Tensor]],
151148
) -> Dict[str, torch.Tensor]:
152-
params = [p for _, p in named_params]
149+
params = [p for _, p in named_parameters]
153150
all_inputs = [loss, prev_loss] + params
154151
features = torch.stack(all_inputs, 0)
155152

@@ -198,8 +195,6 @@ def rebuild_and_script(graph_dict, config, key) -> DynamicOptimizerModule:
198195

199196
# --- build a Python module and script it ---
200197
if genome.connections:
201-
module = DynamicOptimizerModule(
202-
genome, config.input_keys, config.output_keys, graph_dict
203-
)
198+
module = DynamicOptimizerModule(genome, config.input_keys, config.output_keys, graph_dict)
204199
return torch.jit.script(module)
205200
return None

population.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ def evaluate_optimizer(self, optimizer, model, task, steps=10):
292292
task: The task on which to evaluate the optimizer.
293293
steps: Number of update iterations.
294294
"""
295-
# TODO: find way to correct for time improvements that are solely due to RAM cache tiers
295+
# TODO: clear all levels of RAM caches in between every run to create a fair starting point
296+
# for comparison
296297
tracemalloc.start()
297298
start = time.perf_counter()
298299
prev_metrics_values = torch.tensor([0.0] * len(task.metrics))

tests/test_graph_builder.py

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import neat
77
import pytest
88
import torch
9+
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
910

1011
# allow imports from repo root
1112
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[1]))
@@ -27,6 +28,105 @@ def make_config():
2728
)
2829

2930

31+
def _canonical_value(value):
    """Convert an attribute value into a form that is safe to compare with ``==``.

    Tensor-like values (anything exposing ``dtype``, ``shape`` and ``tolist``)
    are reduced to plain Python data, because comparing tuples that contain raw
    tensors evaluates ``tensor == tensor`` in a boolean context, which raises
    for multi-element tensors. Lists/tuples are canonicalized element-wise.
    """
    if hasattr(value, "dtype") and hasattr(value, "shape") and hasattr(value, "tolist"):
        # repr() of the nested list keeps the result hashable and makes NaN payloads comparable
        return ("tensor", str(value.dtype), tuple(value.shape), repr(value.tolist()))
    if isinstance(value, (list, tuple)):
        return tuple(_canonical_value(item) for item in value)
    return value


def get_node_signature(node):
    """Return a comparable structural signature for a single graph node.

    The signature is a 4-tuple of:

    * the node kind (operator name),
    * the kinds of the nodes producing each input, in input order,
    * the types of all outputs, in output order,
    * for ``prim::Constant`` nodes, the sorted ``(name, (attr_kind, value))``
      pairs of the node's attributes; an empty tuple for every other node.

    Args:
        node: A graph node exposing ``kind()``, ``inputs()``, ``outputs()``,
            ``attributeNames()``, ``kindOf(name)`` and per-kind accessors
            (presumably a ``torch._C.Node`` taken from ``ScriptModule.graph``).

    Returns:
        A tuple that can be compared with ``==`` / ``!=`` against the signature
        of another node.
    """
    input_kinds = tuple(inp.node().kind() for inp in node.inputs())

    # Use outputs() rather than output(): the latter is only valid for nodes
    # with exactly one output, while graphs also contain nodes with zero or
    # several outputs.
    output_types = tuple(out.type() for out in node.outputs())

    # TODO: also compare attributes of non-constant nodes for a fully robust comparison
    attributes = {}
    if node.kind() == "prim::Constant":
        for name in node.attributeNames():
            # kindOf() names the accessor matching the stored attribute kind
            # ("i", "f", "s", "t", ...), so the value is read with the right
            # getter instead of assuming it is always a tensor.
            attr_kind = node.kindOf(name)
            raw_value = getattr(node, attr_kind)(name)
            attributes[name] = (attr_kind, _canonical_value(raw_value))

    return (node.kind(), input_kinds, output_types, tuple(sorted(attributes.items())))
48+
49+
50+
def compare_jit_graphs_structural(original: torch.jit.ScriptModule, rebuilt: torch.jit.ScriptModule) -> bool:
    """Check whether two scripted modules have structurally matching graphs.

    The comparison walks both graphs' nodes in iteration order and requires, in turn:

    1. the same number of graph inputs and outputs,
    2. the same number of nodes,
    3. pairwise-equal node signatures (see ``get_node_signature``),
    4. the same named parameters with equal values (``torch.equal``),
    5. matching custom data attributes (see ``compare_custom_data``).

    The first mismatch is reported on ``sys.stderr`` and ends the comparison.

    Args:
        original: The reference scripted module.
        rebuilt: The scripted module to compare against ``original``.

    Returns:
        ``True`` if every check passes, ``False`` at the first mismatch.
    """
    original_inputs = list(original.graph.inputs())
    rebuilt_inputs = list(rebuilt.graph.inputs())
    original_outputs = list(original.graph.outputs())
    rebuilt_outputs = list(rebuilt.graph.outputs())
    if len(original_inputs) != len(rebuilt_inputs) or len(original_outputs) != len(rebuilt_outputs):
        print(
            f"Input/output counts differ: original.graph inputs={len(original_inputs)}, outputs={len(original_outputs)} vs rebuilt inputs={len(rebuilt_inputs)}, outputs={len(rebuilt_outputs)}",
            file=sys.stderr,
        )
        return False

    # default iterator for graph.nodes() is typically a topological sort
    original_nodes = list(original.graph.nodes())
    rebuilt_nodes = list(rebuilt.graph.nodes())

    if len(original_nodes) != len(rebuilt_nodes):
        print(
            f"Number of nodes differ: original.graph has {len(original_nodes)} nodes, rebuilt has {len(rebuilt_nodes)} nodes",
            file=sys.stderr,
        )
        return False

    # Compare nodes position by position. The signature already contains the
    # kind of every input's producing node, so no separate per-input kind
    # check is needed here.
    for i, (original_node, rebuilt_node) in enumerate(zip(original_nodes, rebuilt_nodes)):
        if get_node_signature(original_node) != get_node_signature(rebuilt_node):
            print(f"Signatures differ at node {i}:", file=sys.stderr)
            print(f" original.graph Node Kind: {original_node.kind()}", file=sys.stderr)
            print(f" rebuilt Node Kind: {rebuilt_node.kind()}", file=sys.stderr)
            # TODO: add more detailed diffing here
            return False
        # TODO: need to further compare value properties if they are constants or recursively
        # check if the input nodes themselves are structurally equivalent up to that point

    original_params = dict(original.named_parameters())
    rebuilt_params = dict(rebuilt.named_parameters())
    if len(original_params) != len(rebuilt_params):
        print("Parameter counts differ", file=sys.stderr)
        return False
    for name, original_param in original_params.items():
        if name not in rebuilt_params:
            print(f"Parameter '{name}' missing in rebuilt graph", file=sys.stderr)
            return False
        if not torch.equal(original_param, rebuilt_params[name]):
            print(f"Parameter '{name}' values differ", file=sys.stderr)
            return False

    if not compare_custom_data(original, rebuilt):
        print("Custom data attributes differ", file=sys.stderr)
        return False

    return True
116+
117+
118+
def _custom_values_equal(left, right) -> bool:
    """Compare two custom attribute values, using ``torch.equal`` for tensors.

    Plain ``==`` / ``!=`` on multi-element tensors yields a tensor whose truth
    value is ambiguous (and raises), so tensors are compared with
    ``torch.equal`` instead. A tensor never equals a non-tensor here.
    """
    left_is_tensor = torch.is_tensor(left)
    right_is_tensor = torch.is_tensor(right)
    if left_is_tensor and right_is_tensor:
        return torch.equal(left, right)
    if left_is_tensor or right_is_tensor:
        return False
    return left == right


def compare_custom_data(original: torch.jit.ScriptModule, rebuilt: torch.jit.ScriptModule) -> bool:
    """Compare the optional ``node_types`` and ``edge_index`` attributes of two modules.

    An attribute is only compared when BOTH modules define it; an attribute
    present on just one side is ignored (the original lenient behavior is kept).
    The first difference is reported on ``sys.stderr``.

    Args:
        original: The reference module.
        rebuilt: The module to compare against ``original``.

    Returns:
        ``False`` if an attribute defined on both modules differs, else ``True``.
    """
    for attr_name in ("node_types", "edge_index"):
        if not (hasattr(original, attr_name) and hasattr(rebuilt, attr_name)):
            continue
        if not _custom_values_equal(getattr(original, attr_name), getattr(rebuilt, attr_name)):
            print(f"{attr_name} differ", file=sys.stderr)
            return False
    return True
128+
129+
30130
@pytest.mark.parametrize("pt_path", glob.glob(os.path.join("computation_graphs", "optimizers", "*.pt")))
31131
def test_graph_builder_rebuilds_pt(pt_path):
32132
original = torch.jit.load(pt_path)
@@ -51,11 +151,5 @@ def test_graph_builder_rebuilds_pt(pt_path):
51151
assert len(list(rebuilt.parameters())) == len(expected_edges)
52152
assert len(rebuilt.node_types) == len(data.node_types)
53153

54-
# Verify that the rebuilt computation graph is identical to the original
55-
if str(rebuilt.graph) != str(original.graph):
56-
print("Original graph:\n", original.graph)
57-
print("Rebuilt graph:\n", rebuilt.graph)
58-
assert str(rebuilt.graph) == str(original.graph), (
59-
"\nOriginal graph:\n" + str(original.graph) +
60-
"\nRebuilt graph:\n" + str(rebuilt.graph)
61-
)
154+
# Verify that the rebuilt computation graph is structurally identical to the original
155+
assert compare_jit_graphs_structural(rebuilt, original)

0 commit comments

Comments
 (0)