Commit d4e9483

layer_norm no-bias for Quack benchmark
1 parent 9e5d80e commit d4e9483

File tree

2 files changed (+86, -6 lines)

benchmarks/run.py
Lines changed: 2 additions & 2 deletions

@@ -50,7 +50,7 @@
     "layer_norm": (
         "tritonbench.operators.layer_norm.operator",
         "examples.layer_norm",
-        "layer_norm_fwd",
+        "layer_norm_fwd_tritonbench",
     ),
     "softmax": (
         "tritonbench.operators.softmax.operator",

@@ -166,7 +166,7 @@
     "layer_norm": (
         "tritonbench.operators.layer_norm.operator",
         "examples.layer_norm",
-        "layer_norm_fwd",
+        "layer_norm_fwd_tritonbench",
    ),
     "jagged_softmax": (
         "tritonbench.operators.jagged_softmax.operator",

examples/layer_norm.py
Lines changed: 84 additions & 4 deletions

@@ -17,15 +17,15 @@
 
 # %%
 @helion.kernel
-def layer_norm_fwd(
+def layer_norm_fwd_with_bias(
     x: torch.Tensor,
     nomralized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float = 1e-5,
 ) -> torch.Tensor:
     """
-    Performs 1D layer normalization on the input tensor using Helion.
+    Performs 1D layer normalization with bias on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).

@@ -54,13 +54,66 @@ def layer_norm_fwd
     return out
 
 
+@helion.kernel
+def layer_norm_fwd_no_bias(
+    x: torch.Tensor,
+    nomralized_shape: list[int],
+    weight: torch.Tensor,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """
+    Performs 1D layer normalization without bias on the input tensor using Helion.
+    Args:
+        x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
+        nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        weight (torch.Tensor): Learnable scale parameter of shape [dim].
+        eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
+    Returns:
+        torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+    """
+    m, n = x.size()
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
+    assert len(nomralized_shape) == 1, (
+        "Helion layer norm only supports 1D layer norm currently"
+    )
+    assert nomralized_shape[0] == n, (
+        f"normalized shape mismatch {nomralized_shape[0]} != {n}"
+    )
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    for tile_m in hl.tile(m):
+        acc = x[tile_m, :].to(torch.float32)
+        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
+        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        acc = normalized * (weight[:].to(torch.float32))
+        out[tile_m, :] = acc
+    return out
+
+
+def layer_norm_fwd_tritonbench(
+    x: torch.Tensor,
+    nomralized_shape: list[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """
+    Wrapper function that dispatches to the appropriate layer normalization kernel.
+    Compatible with tritonbench, which may pass None for bias.
+    """
+    if bias is None:
+        return layer_norm_fwd_no_bias(x, nomralized_shape, weight, eps)
+    else:
+        return layer_norm_fwd_with_bias(x, nomralized_shape, weight, bias, eps)
+
+
 # %%
 def main() -> None:
     """
     Main execution function for the layer normalization example.
     - Generates random input, weight, and bias tensors.
     - Runs the Helion layer normalization kernel and compares its output to PyTorch's
       built-in layer_norm function using the run_example utility.
+    - Tests both with bias and without bias (no-bias mode).
     - Prints comparison results and checks for correctness within specified tolerances.
     """
     batch_size = 32

@@ -70,15 +123,42 @@ def main() -> None:
     weight = torch.randn([dim], device=device, dtype=torch.float16)
     bias = torch.randn([dim], device=device, dtype=torch.float16)
     eps = 1e-4
+
+    # Test with bias
+    print("Testing layer_norm WITH bias:")
     run_example(
-        layer_norm_fwd,
+        layer_norm_fwd_with_bias,
         torch.nn.functional.layer_norm,
         (x, [dim], weight, bias, eps),
-        kernel_name="helion",
+        kernel_name="helion_with_bias",
+        baseline_name="torch",
+        rtol=1e-3,
+        atol=1e-3,
+    )
+
+    # Test without bias (no-bias mode)
+    print("\nTesting layer_norm WITHOUT bias (no-bias mode):")
+    run_example(
+        layer_norm_fwd_no_bias,
+        lambda x, shape, w, e: torch.nn.functional.layer_norm(x, shape, w, None, e),
+        (x, [dim], weight, eps),
+        kernel_name="helion_no_bias",
         baseline_name="torch",
         rtol=1e-3,
         atol=1e-3,
     )
+
+    # Test wrapper function with bias
+    print("\nTesting wrapper function WITH bias:")
+    result_with_bias = layer_norm_fwd_tritonbench(x, [dim], weight, bias, eps)
+    expected_with_bias = torch.nn.functional.layer_norm(x, [dim], weight, bias, eps)
+    print(f"  Wrapper with bias matches torch: {torch.allclose(result_with_bias, expected_with_bias, rtol=1e-3, atol=1e-3)}")
+
+    # Test wrapper function without bias
+    print("\nTesting wrapper function WITHOUT bias:")
+    result_no_bias = layer_norm_fwd_tritonbench(x, [dim], weight, None, eps)
+    expected_no_bias = torch.nn.functional.layer_norm(x, [dim], weight, None, eps)
+    print(f"  Wrapper without bias matches torch: {torch.allclose(result_no_bias, expected_no_bias, rtol=1e-3, atol=1e-3)}")
 
 
 # %%
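For reference (not part of the commit), the arithmetic that layer_norm_fwd_no_bias performs per tile can be written in eager PyTorch as the sketch below; it should match torch.nn.functional.layer_norm(x, [dim], weight, None, eps) up to FP16 rounding.

import torch

def layer_norm_no_bias_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Upcast to FP32, normalize over the last dim with population statistics
    # (correction=0), scale by weight, and cast back to FP16, mirroring the
    # per-tile steps in the Helion kernel above.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    normalized = (acc - mean) * torch.rsqrt(var + eps)
    return (normalized * weight.to(torch.float32)).to(torch.float16)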
