|
29 | 29 | from typing import Callable
|
30 | 30 | import time
|
31 | 31 |
|
| 32 | +import torch |
| 33 | + |
32 | 34 | # Maps tritonbench op names to Helion kernel examples
|
33 | 35 | # Can map to a single kernel or a list of kernel variants
|
34 | 36 | # Format options:
|
@@ -436,20 +438,31 @@ def run_kernel_variants(
|
436 | 438 | register_benchmark,
|
437 | 439 | )
|
438 | 440 |
|
439 |
| - # Inject only_shapes filter if provided |
| 441 | + # Always extract all inputs beforehand |
| 442 | + # Override the get_input_iter method for the operator class |
| 443 | + original_get_input_iter = Operator.get_input_iter |
| 444 | + original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None |
| 445 | + |
| 446 | + # Create a list to store all inputs |
| 447 | + all_inputs = [] |
| 448 | + |
| 449 | + # Collect all inputs |
| 450 | + torch.manual_seed(42) |
| 451 | + temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args) |
| 452 | + for inputs in original_get_input_iter(temp_operator): |
| 453 | + # Re-seed after each collected input so the next one (if generated lazily by the iterator) starts from the same RNG state |
| 454 | + torch.manual_seed(42) |
| 455 | + all_inputs.append(inputs) |
| 456 | + |
| 457 | + # If only_shapes is provided, filter the inputs |
440 | 458 | if only_shapes:
|
441 | 459 | print(f"Using only_shapes for {kernel_name}: {only_shapes}", file=sys.stderr)
|
442 | 460 |
|
443 |
| - # Override the get_input_iter method for the operator class |
444 |
| - original_get_input_iter = Operator.get_input_iter |
445 |
| - original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None |
446 |
| - |
447 |
| - # Create a list to store filtered inputs and their shapes |
| 461 | + # Create a list to store filtered inputs |
448 | 462 | filtered_inputs = []
|
449 | 463 |
|
450 |
| - # First, collect all inputs that match the shape filter |
451 |
| - temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args) |
452 |
| - for inputs in original_get_input_iter(temp_operator): |
| 464 | + # Filter inputs that match the shape filter |
| 465 | + for inputs in all_inputs: |
453 | 466 | # Get the shape value for this input
|
454 | 467 | shape_value = None
|
455 | 468 |
|
@@ -483,21 +496,28 @@ def run_kernel_variants(
|
483 | 496 | filtered_inputs.append(inputs)
|
484 | 497 | print(f" Including shape: {shape_value}", file=sys.stderr)
|
485 | 498 |
|
486 |
| - del temp_operator # Clean up temporary operator |
487 |
| - |
488 | 499 | if not filtered_inputs:
|
489 | 500 | print(f"Warning: No shapes matched the filter for {kernel_name}", file=sys.stderr)
|
490 | 501 |
|
491 |
| - def filtered_get_input_iter(self): |
492 |
| - """Custom input iterator that only yields filtered shapes.""" |
493 |
| - for inputs in filtered_inputs: |
494 |
| - yield inputs |
495 |
| - |
496 |
| - # Monkey-patch the operator class |
497 |
| - Operator.get_input_iter = filtered_get_input_iter |
498 |
| - |
499 |
| - # Also override _available_num_inputs for proper sharding support |
500 |
| - Operator._available_num_inputs = len(filtered_inputs) |
| 502 | + # Use filtered inputs instead of all inputs |
| 503 | + inputs_to_use = filtered_inputs |
| 504 | + else: |
| 505 | + # Use all inputs |
| 506 | + inputs_to_use = all_inputs |
| 507 | + |
| 508 | + del temp_operator # Clean up temporary operator |
| 509 | + |
| 510 | + # Create a new input iterator function |
| 511 | + def new_get_input_iter(self): |
| 512 | + """Custom input iterator that yields pre-collected inputs.""" |
| 513 | + for inputs in inputs_to_use: |
| 514 | + yield inputs |
| 515 | + |
| 516 | + # Monkey-patch the operator class |
| 517 | + Operator.get_input_iter = new_get_input_iter |
| 518 | + |
| 519 | + # Also override _available_num_inputs for proper sharding support |
| 520 | + Operator._available_num_inputs = len(inputs_to_use) |
501 | 521 |
|
502 | 522 | # Register all variants as separate methods
|
503 | 523 | for module_path, func_name in variants:
|
|
0 commit comments