Merged
36 changes: 32 additions & 4 deletions graph_net/torch/utils.py
@@ -221,7 +221,17 @@ def convert_meta_classes_to_tensors(file_path):
data_type = getattr(torch, attrs.get("dtype", "torch.float").split(".")[-1])
shape = attrs.get("shape", [])

if "min_val" in attrs and "max_val" in attrs:
if (
"min_val" in attrs
and "max_val" in attrs
and data_type
in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]
):
min_val = attrs["min_val"]
max_val = attrs["max_val"]
# torch.randint's upper bound is exclusive, so add 1
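
The widened guard above routes only integer dtypes to torch.randint; a float tensor that carries min_val/max_val now falls through to the normal-distribution path and is clamped later in replay_tensor instead. A minimal sketch of that dispatch, reusing the attrs/shape/data_type names from the hunk above (the helper name _sample_from_meta is illustrative, not part of the module):

import torch

def _sample_from_meta(attrs, shape, data_type, device="cpu"):
    # Integer dtypes with explicit bounds can be sampled exactly with randint.
    int_types = (torch.int8, torch.int16, torch.int32, torch.int64)
    if "min_val" in attrs and "max_val" in attrs and data_type in int_types:
        # torch.randint's upper bound is exclusive, so add 1 to include max_val.
        return torch.randint(
            low=int(attrs["min_val"]),
            high=int(attrs["max_val"]) + 1,
            size=tuple(shape),
            dtype=data_type,
            device=device,
        )
    # Anything else falls back to a normal sample; mean/std scaling and the
    # min_val/max_val clamps are applied later, in replay_tensor.
    return torch.randn(size=tuple(shape), device=device).to(data_type)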
@@ -242,9 +252,11 @@ def convert_meta_classes_to_tensors(file_path):
"mean": attrs.get("mean", 0.0),
"std": attrs.get("std", 1.0),
}
# Include min_val if present (for batch_norm running_var constraints)
# Include constraints if present (floats will be clamped in replay_tensor)
if "min_val" in attrs:
info_dict["min_val"] = attrs["min_val"]
if "max_val" in attrs:
info_dict["max_val"] = attrs["max_val"]

yield {
"info": info_dict,
@@ -280,12 +292,28 @@ def replay_tensor(info):
std = 0.1
if mean is None:
mean = 0
tensor = torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
# Handle std = 0 case to avoid generating identical values
if std == 0:
tensor = torch.full(size=shape, fill_value=mean, dtype=dtype, device=device)
else:
tensor = torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean

# Apply min_val constraint if present (for batch_norm running_var)
# Apply lower/upper bound constraints if present
if "min_val" in info["info"]:
min_val = info["info"]["min_val"]
tensor = torch.clamp(tensor, min=min_val)
if "max_val" in info["info"]:
max_val = info["info"]["max_val"]
tensor = torch.clamp(tensor, max=max_val)

# Additional numerical stability checks
if dtype.is_floating_point:
# Replace any inf or nan values with small random values
tensor = torch.where(
torch.isfinite(tensor), tensor, torch.randn_like(tensor) * 0.01
)
# Ensure no extremely large values
tensor = torch.clamp(tensor, min=-100.0, max=100.0)

return tensor

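Taken together, the replay_tensor changes mean: a constant tensor (std == 0) is rebuilt with torch.full instead of a degenerate randn draw, any min_val/max_val recorded in the meta class is enforced by clamping, and non-finite or extreme float values are scrubbed. A small usage sketch, assuming an info dict shaped like the ones yielded above (the shape/dtype/device key names are assumptions; only mean, std, min_val and max_val appear in the visible hunks):

import torch

info = {
    "info": {
        "shape": [2, 3],
        "dtype": torch.float32,
        "device": "cpu",
        "mean": 0.5,
        "std": 0.0,      # constant tensor -> torch.full path
        "min_val": 0.0,  # clamp bounds from the meta class
        "max_val": 1.0,
    }
}

t = replay_tensor(info)                   # the function patched above
assert torch.all(t == 0.5)                # std == 0 fills every element with the mean
assert torch.isfinite(t).all()            # inf/nan would have been replaced
assert t.min() >= 0.0 and t.max() <= 1.0  # min_val/max_val clamps hold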
@@ -46,10 +46,10 @@ def forward(
bool_1 = None
invert = ~getitem_1
getitem_1 = None
output_1 = output.masked_fill(invert, -inf)
output_1 = output.masked_fill(invert, -1e6)
output = invert = None
new_output = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output[(Ellipsis, slice(None, 7, None))] = output_1
setitem = new_output
@@ -95,10 +95,10 @@ def forward(
bool_2 = None
invert_1 = ~getitem_5
getitem_5 = None
output_3 = output_2.masked_fill(invert_1, -inf)
output_3 = output_2.masked_fill(invert_1, -1e6)
output_2 = invert_1 = None
new_output_1 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_1[(Ellipsis, slice(None, 7, None))] = output_3
setitem_1 = new_output_1
@@ -144,10 +144,10 @@ def forward(
bool_3 = None
invert_2 = ~getitem_9
getitem_9 = None
output_5 = output_4.masked_fill(invert_2, -inf)
output_5 = output_4.masked_fill(invert_2, -1e6)
output_4 = invert_2 = None
new_output_2 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_2[(Ellipsis, slice(None, 7, None))] = output_5
setitem_2 = new_output_2
@@ -193,10 +193,10 @@ def forward(
bool_4 = None
invert_3 = ~getitem_13
getitem_13 = None
output_7 = output_6.masked_fill(invert_3, -inf)
output_7 = output_6.masked_fill(invert_3, -1e6)
output_6 = invert_3 = None
new_output_3 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_3[(Ellipsis, slice(None, 7, None))] = output_7
setitem_3 = new_output_3
@@ -242,10 +242,10 @@ def forward(
bool_5 = None
invert_4 = ~getitem_17
getitem_17 = None
output_9 = output_8.masked_fill(invert_4, -inf)
output_9 = output_8.masked_fill(invert_4, -1e6)
output_8 = invert_4 = None
new_output_4 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_4[(Ellipsis, slice(None, 7, None))] = output_9
setitem_4 = new_output_4
@@ -294,10 +294,10 @@ def forward(
bool_6 = None
invert_5 = ~getitem_21
getitem_21 = None
output_11 = output_10.masked_fill(invert_5, -inf)
output_11 = output_10.masked_fill(invert_5, -1e6)
output_10 = invert_5 = None
new_output_5 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_5[(Ellipsis, slice(None, 7, None))] = output_11
setitem_5 = new_output_5
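
The -inf to -1e6 substitution repeated through this file is a numerical-stability fix for the replayed graph: once a row has been filled entirely with -inf, a downstream reduction such as softmax turns it into NaN, whereas a large finite negative value degrades to a uniform row instead. A standalone illustration of the failure mode (not taken from the model):

import torch

scores = torch.randn(1, 3)
mask = torch.zeros(1, 3, dtype=torch.bool)    # no position is valid

filled_inf = scores.masked_fill(~mask, float("-inf"))
filled_big = scores.masked_fill(~mask, -1e6)

print(torch.softmax(filled_inf, dim=-1))      # all NaN: (-inf) - (-inf) is undefined
print(torch.softmax(filled_big, dim=-1))      # finite and uniform over the row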
@@ -6,6 +6,8 @@ class Program_weight_tensor_meta_L_stack0_encoder_last_hidden_state_text:
mean = 0.000
std = 1.000
data = None
min_val = -10.0
max_val = 10.0


class Program_weight_tensor_meta_L_stack0_intermediate_hidden_states:
@@ -16,6 +18,8 @@ class Program_weight_tensor_meta_L_stack0_intermediate_hidden_states:
mean = 0.000
std = 1.000
data = None
min_val = -10.0
max_val = 10.0


class Program_weight_tensor_meta_L_stack0_init_reference_points:
@@ -26,6 +30,8 @@ class Program_weight_tensor_meta_L_stack0_init_reference_points:
mean = 0.400
std = 0.296
data = None
min_val = 0.0
max_val = 1.0


class Program_weight_tensor_meta_L_stack0_intermediate_reference_points:
@@ -36,6 +42,8 @@ class Program_weight_tensor_meta_L_stack0_intermediate_reference_points:
mean = 0.400
std = 0.296
data = None
min_val = 0.0
max_val = 1.0


class Program_weight_tensor_meta_L_attention_mask_:
@@ -56,6 +64,8 @@ class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_lay
mean = -0.000
std = 0.020
data = None
min_val = -1.0
max_val = 1.0


class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_layers_modules_0_parameters_bias_:
@@ -68,6 +78,8 @@ class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_lay
mean = 0.000
std = 0.000
data = None
min_val = -1.0
max_val = 1.0


class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_layers_modules_1_parameters_weight_:
Expand All @@ -78,6 +90,8 @@ class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_lay
mean = 0.000
std = 0.020
data = None
min_val = -1.0
max_val = 1.0


class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_layers_modules_1_parameters_bias_:
Expand All @@ -90,6 +104,8 @@ class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_lay
mean = 0.000
std = 0.000
data = None
min_val = -1.0
max_val = 1.0


class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_layers_modules_2_parameters_weight_:
Expand All @@ -100,6 +116,8 @@ class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_lay
mean = 0.000
std = 0.000
data = None
min_val = -1.0
max_val = 1.0


class Program_weight_tensor_meta_L_self_modules_bbox_embed_modules_0_modules_layers_modules_2_parameters_bias_:
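
The min_val/max_val pairs added to these meta classes feed the clamping path introduced in utils.py: reference points stay in [0, 1], and hidden states and bbox-head weights stay in a bounded range, so the replayed tensors match the value ranges the original run produced. Roughly, for the init_reference_points class above (the shape is an illustrative placeholder, not taken from the meta file):

import torch

shape = (1, 900, 4)                               # illustrative placeholder
t = torch.randn(shape) * 0.296 * 0.2 + 0.400      # mean/std recorded in the class
t = torch.clamp(t, min=0.0, max=1.0)              # min_val/max_val added above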
24 changes: 12 additions & 12 deletions samples/transformers-auto-model/fushh7_llmdet_swin_tiny_hf/model.py
@@ -106,10 +106,10 @@ def forward(
bool_1 = None
invert = ~getitem_1
getitem_1 = None
output_1 = output.masked_fill(invert, -inf)
output_1 = output.masked_fill(invert, -1e6)
output = invert = None
new_output = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output[(Ellipsis, slice(None, 7, None))] = output_1
setitem = new_output
@@ -155,10 +155,10 @@ def forward(
bool_2 = None
invert_1 = ~getitem_5
getitem_5 = None
output_3 = output_2.masked_fill(invert_1, -inf)
output_3 = output_2.masked_fill(invert_1, -1e6)
output_2 = invert_1 = None
new_output_1 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_1[(Ellipsis, slice(None, 7, None))] = output_3
setitem_1 = new_output_1
@@ -204,10 +204,10 @@ def forward(
bool_3 = None
invert_2 = ~getitem_9
getitem_9 = None
output_5 = output_4.masked_fill(invert_2, -inf)
output_5 = output_4.masked_fill(invert_2, -1e6)
output_4 = invert_2 = None
new_output_2 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_2[(Ellipsis, slice(None, 7, None))] = output_5
setitem_2 = new_output_2
@@ -253,10 +253,10 @@ def forward(
bool_4 = None
invert_3 = ~getitem_13
getitem_13 = None
output_7 = output_6.masked_fill(invert_3, -inf)
output_7 = output_6.masked_fill(invert_3, -1e6)
output_6 = invert_3 = None
new_output_3 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_3[(Ellipsis, slice(None, 7, None))] = output_7
setitem_3 = new_output_3
@@ -302,10 +302,10 @@ def forward(
bool_5 = None
invert_4 = ~getitem_17
getitem_17 = None
output_9 = output_8.masked_fill(invert_4, -inf)
output_9 = output_8.masked_fill(invert_4, -1e6)
output_8 = invert_4 = None
new_output_4 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_4[(Ellipsis, slice(None, 7, None))] = output_9
setitem_4 = new_output_4
@@ -354,10 +354,10 @@ def forward(
bool_6 = None
invert_5 = ~getitem_21
getitem_21 = None
output_11 = output_10.masked_fill(invert_5, -inf)
output_11 = output_10.masked_fill(invert_5, -1e6)
output_10 = invert_5 = None
new_output_5 = torch.full(
(1, 900, 256), -inf, device=device(type="cuda", index=0)
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
)
new_output_5[(Ellipsis, slice(None, 7, None))] = output_11
setitem_5 = new_output_5
@@ -26,6 +26,8 @@ class Program_weight_tensor_meta_L_stack0_init_reference_points:
mean = 0.347
std = 0.339
data = None
min_val = 0.0
max_val = 1.0


class Program_weight_tensor_meta_L_stack0_intermediate_reference_points:
Expand All @@ -36,6 +38,8 @@ class Program_weight_tensor_meta_L_stack0_intermediate_reference_points:
mean = 0.347
std = 0.339
data = None
min_val = 0.0
max_val = 1.0


class Program_weight_tensor_meta_L_attention_mask_: