Skip to content

Commit eb99525

Browse files
committed
Always generate inputs beforehand, with fixed initial seeds
1 parent eaa1848 commit eb99525

File tree

1 file changed

+41
-21
lines changed

1 file changed

+41
-21
lines changed

benchmarks/run.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from typing import Callable
3030
import time
3131

32+
import torch
33+
3234
# Maps tritonbench op names to Helion kernel examples
3335
# Can map to a single kernel or a list of kernel variants
3436
# Format options:
@@ -436,20 +438,31 @@ def run_kernel_variants(
436438
register_benchmark,
437439
)
438440

439-
# Inject only_shapes filter if provided
441+
# Always extract all inputs beforehand
442+
# Override the get_input_iter method for the operator class
443+
original_get_input_iter = Operator.get_input_iter
444+
original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None
445+
446+
# Create a list to store all inputs
447+
all_inputs = []
448+
449+
# Collect all inputs
450+
torch.manual_seed(42)
451+
temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args)
452+
for inputs in original_get_input_iter(temp_operator):
453+
# Set random seed for reproducibility
454+
torch.manual_seed(42)
455+
all_inputs.append(inputs)
456+
457+
# If only_shapes is provided, filter the inputs
440458
if only_shapes:
441459
print(f"Using only_shapes for {kernel_name}: {only_shapes}", file=sys.stderr)
442460

443-
# Override the get_input_iter method for the operator class
444-
original_get_input_iter = Operator.get_input_iter
445-
original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None
446-
447-
# Create a list to store filtered inputs and their shapes
461+
# Create a list to store filtered inputs
448462
filtered_inputs = []
449463

450-
# First, collect all inputs that match the shape filter
451-
temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args)
452-
for inputs in original_get_input_iter(temp_operator):
464+
# Filter inputs that match the shape filter
465+
for inputs in all_inputs:
453466
# Get the shape value for this input
454467
shape_value = None
455468

@@ -483,21 +496,28 @@ def run_kernel_variants(
483496
filtered_inputs.append(inputs)
484497
print(f" Including shape: {shape_value}", file=sys.stderr)
485498

486-
del temp_operator # Clean up temporary operator
487-
488499
if not filtered_inputs:
489500
print(f"Warning: No shapes matched the filter for {kernel_name}", file=sys.stderr)
490501

491-
def filtered_get_input_iter(self):
492-
"""Custom input iterator that only yields filtered shapes."""
493-
for inputs in filtered_inputs:
494-
yield inputs
495-
496-
# Monkey-patch the operator class
497-
Operator.get_input_iter = filtered_get_input_iter
498-
499-
# Also override _available_num_inputs for proper sharding support
500-
Operator._available_num_inputs = len(filtered_inputs)
502+
# Use filtered inputs instead of all inputs
503+
inputs_to_use = filtered_inputs
504+
else:
505+
# Use all inputs
506+
inputs_to_use = all_inputs
507+
508+
del temp_operator # Clean up temporary operator
509+
510+
# Create a new input iterator function
511+
def new_get_input_iter(self):
512+
"""Custom input iterator that yields pre-collected inputs."""
513+
for inputs in inputs_to_use:
514+
yield inputs
515+
516+
# Monkey-patch the operator class
517+
Operator.get_input_iter = new_get_input_iter
518+
519+
# Also override _available_num_inputs for proper sharding support
520+
Operator._available_num_inputs = len(inputs_to_use)
501521

502522
# Register all variants as separate methods
503523
for module_path, func_name in variants:

0 commit comments

Comments
 (0)