66import neat
77import pytest
88import torch
9+ from torch .fx .passes .utils .matcher_utils import SubgraphMatcher
910
1011# allow imports from repo root
1112sys .path .insert (0 , str (pathlib .Path (__file__ ).resolve ().parents [1 ]))
@@ -27,6 +28,105 @@ def make_config():
2728 )
2829
2930
def get_node_signature(node):
    """Build an equality-comparable signature for a TorchScript graph node.

    The signature is a 4-tuple made of:

    * the node kind (operator name),
    * a tuple with the kind of the node producing each input,
    * the output type: the single type for one-output nodes, otherwise a
      tuple of the output types (empty for nodes without outputs),
    * the sorted attribute items; only the value of ``prim::Constant``
      nodes is recorded for now.

    Signatures are only meant to be compared with ``==`` / ``!=``.
    """
    # TODO: for robust comparison, also compare attributes of non-constant nodes
    input_kinds = tuple(inp.node().kind() for inp in node.inputs())

    # node.output() raises unless the node has exactly one output, so collect
    # the types through outputs() and keep the single-output shape unchanged.
    output_types = [out.type() for out in node.outputs()]
    output_type = output_types[0] if len(output_types) == 1 else tuple(output_types)

    attributes = {}
    if node.kind() == "prim::Constant":
        if node.hasAttribute("value"):
            # The attribute is always named "value"; its *kind* ("i", "f",
            # "s", "t", "ival", ...) selects the accessor that can read it.
            attr_kind = node.kindOf("value")
            accessor = getattr(node, attr_kind, None)
            if accessor is None:
                # No Python accessor for this kind: compare the kind only.
                value = attr_kind
            else:
                value = accessor("value")
                if attr_kind == "t":
                    # Raw tensors do not support tuple equality (ambiguous
                    # truth value), so store a plain-Python canonical form.
                    value = (str(value.dtype), tuple(value.shape), value.tolist())
            attributes["value"] = value
        elif node.hasAttribute("i"):
            attributes["value"] = node.i("i")
        elif node.hasAttribute("f"):
            attributes["value"] = node.f("f")

    return (node.kind(), input_kinds, output_type, tuple(sorted(attributes.items())))
48+
49+
def compare_jit_graphs_structural(original: torch.jit.ScriptModule, rebuilt: torch.jit.ScriptModule) -> bool:
    """Return True when two scripted modules have structurally matching graphs.

    The check proceeds in order and stops at the first mismatch, printing a
    short diagnostic to stderr:

    1. graph input/output counts,
    2. top-level node counts,
    3. node-by-node signatures (see ``get_node_signature``) and data-flow
       wiring, comparing nodes position by position,
    4. named parameters (same names, equal tensors),
    5. custom data attributes (see ``compare_custom_data``).

    Only the nodes yielded by ``graph.nodes()`` are visited; nodes nested in
    sub-blocks are not compared.
    """
    original_inputs = list(original.graph.inputs())
    rebuilt_inputs = list(rebuilt.graph.inputs())
    original_outputs = list(original.graph.outputs())
    rebuilt_outputs = list(rebuilt.graph.outputs())
    if len(original_inputs) != len(rebuilt_inputs) or len(original_outputs) != len(rebuilt_outputs):
        print(
            f"Input/output counts differ: original.graph inputs={len(original_inputs)}, "
            f"outputs={len(original_outputs)} vs rebuilt inputs={len(rebuilt_inputs)}, "
            f"outputs={len(rebuilt_outputs)}",
            file=sys.stderr,
        )
        return False

    # default iterator for graph.nodes() is typically a topological sort
    original_nodes = list(original.graph.nodes())
    rebuilt_nodes = list(rebuilt.graph.nodes())

    if len(original_nodes) != len(rebuilt_nodes):
        print(
            f"Number of nodes differ: original.graph has {len(original_nodes)} nodes, "
            f"rebuilt has {len(rebuilt_nodes)} nodes",
            file=sys.stderr,
        )
        return False

    # Maps the unique id of each value already seen in the original graph to
    # the id of its positional counterpart in the rebuilt graph. Seeded with
    # the graph inputs and extended with every compared node's outputs.
    value_map = {
        original_value.unique(): rebuilt_value.unique()
        for original_value, rebuilt_value in zip(original_inputs, rebuilt_inputs)
    }

    for i, (original_node, rebuilt_node) in enumerate(zip(original_nodes, rebuilt_nodes)):
        if get_node_signature(original_node) != get_node_signature(rebuilt_node):
            print(f"Signatures differ at node {i}:", file=sys.stderr)
            print(f"  original.graph Node Kind: {original_node.kind()}", file=sys.stderr)
            print(f"  rebuilt Node Kind: {rebuilt_node.kind()}", file=sys.stderr)
            # TODO: add more detailed diffing here
            return False

        # The signature already guarantees matching input counts and producer
        # kinds; what remains is that each input is fed by the *corresponding*
        # value. Values without a recorded counterpart are skipped.
        for input_idx, (original_input_val, rebuilt_input_val) in enumerate(
            zip(original_node.inputs(), rebuilt_node.inputs())
        ):
            expected_id = value_map.get(original_input_val.unique())
            if expected_id is not None and expected_id != rebuilt_input_val.unique():
                print(f"Input wiring differs for node {i}, input {input_idx}", file=sys.stderr)
                return False

        for original_output_val, rebuilt_output_val in zip(original_node.outputs(), rebuilt_node.outputs()):
            value_map[original_output_val.unique()] = rebuilt_output_val.unique()

    # The graph outputs must be produced by corresponding values as well.
    for output_idx, (original_output_val, rebuilt_output_val) in enumerate(
        zip(original_outputs, rebuilt_outputs)
    ):
        expected_id = value_map.get(original_output_val.unique())
        if expected_id is not None and expected_id != rebuilt_output_val.unique():
            print(f"Graph output {output_idx} is wired differently", file=sys.stderr)
            return False

    original_params = dict(original.named_parameters())
    rebuilt_params = dict(rebuilt.named_parameters())
    if len(original_params) != len(rebuilt_params):
        print("Parameter counts differ", file=sys.stderr)
        return False
    for name, original_param in original_params.items():
        if name not in rebuilt_params:
            print(f"Parameter '{name}' missing in rebuilt graph", file=sys.stderr)
            return False
        if not torch.equal(original_param, rebuilt_params[name]):
            print(f"Parameter '{name}' values differ", file=sys.stderr)
            return False

    if not compare_custom_data(original, rebuilt):
        print("Custom data attributes differ", file=sys.stderr)
        return False

    return True
116+
117+
def compare_custom_data(original: torch.jit.ScriptModule, rebuilt: torch.jit.ScriptModule) -> bool:
    """Compare the optional ``node_types`` and ``edge_index`` attributes.

    An attribute is compared only when both modules define it; an attribute
    present on just one side (or on neither) is not treated as a difference.
    ``node_types`` is compared with ``!=`` and ``edge_index`` with
    ``torch.equal``. The first mismatch is reported on stderr and yields
    False; otherwise True is returned.
    """

    def _present_on_both(attr_name):
        return hasattr(original, attr_name) and hasattr(rebuilt, attr_name)

    if _present_on_both("node_types") and original.node_types != rebuilt.node_types:
        print("node_types differ", file=sys.stderr)
        return False

    if _present_on_both("edge_index") and not torch.equal(original.edge_index, rebuilt.edge_index):
        print("edge_index differ", file=sys.stderr)
        return False

    return True
128+
129+
30130@pytest .mark .parametrize ("pt_path" , glob .glob (os .path .join ("computation_graphs" , "optimizers" , "*.pt" )))
31131def test_graph_builder_rebuilds_pt (pt_path ):
32132 original = torch .jit .load (pt_path )
@@ -51,11 +151,5 @@ def test_graph_builder_rebuilds_pt(pt_path):
51151 assert len (list (rebuilt .parameters ())) == len (expected_edges )
52152 assert len (rebuilt .node_types ) == len (data .node_types )
53153
54- # Verify that the rebuilt computation graph is identical to the original
55- if str (rebuilt .graph ) != str (original .graph ):
56- print ("Original graph:\n " , original .graph )
57- print ("Rebuilt graph:\n " , rebuilt .graph )
58- assert str (rebuilt .graph ) == str (original .graph ), (
59- "\n Original graph:\n " + str (original .graph ) +
60- "\n Rebuilt graph:\n " + str (rebuilt .graph )
61- )
154+ # Verify that the rebuilt computation graph is structurally identical to the original
155+ assert compare_jit_graphs_structural (rebuilt , original )
0 commit comments