# %%
@helion.kernel
-def layer_norm_fwd(
+def layer_norm_fwd_with_bias(
    x: torch.Tensor,
    nomralized_shape: list[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    """
-    Performs 1D layer normalization on the input tensor using Helion.
+    Performs 1D layer normalization with bias on the input tensor using Helion.
    Args:
        x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
        nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
@@ -54,13 +54,66 @@ def layer_norm_fwd(
    return out


+@helion.kernel
+def layer_norm_fwd_no_bias(
+    x: torch.Tensor,
+    nomralized_shape: list[int],
+    weight: torch.Tensor,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """
+    Performs 1D layer normalization without bias on the input tensor using Helion.
+    Args:
+        x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
+        nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        weight (torch.Tensor): Learnable scale parameter of shape [dim].
+        eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
+    Returns:
+        torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+    """
+    m, n = x.size()
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
+    assert len(nomralized_shape) == 1, (
+        "Helion layer norm only supports 1D layer norm currently"
+    )
+    assert nomralized_shape[0] == n, (
+        f"normalized shape mismatch {nomralized_shape[0]} != {n}"
+    )
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
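+    # Tile over the batch dimension: each tile of rows is loaded, normalized in fp32
+    # for accuracy, scaled by weight, and written back to the fp16 output.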
+    for tile_m in hl.tile(m):
+        acc = x[tile_m, :].to(torch.float32)
+        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
+        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        acc = normalized * (weight[:].to(torch.float32))
+        out[tile_m, :] = acc
+    return out
+
+
+def layer_norm_fwd_tritonbench(
+    x: torch.Tensor,
+    nomralized_shape: list[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """
+    Wrapper function that dispatches to the appropriate layer normalization kernel.
+    Compatible with tritonbench, which may pass None for bias.
+    """
+    if bias is None:
+        return layer_norm_fwd_no_bias(x, nomralized_shape, weight, eps)
+    else:
+        return layer_norm_fwd_with_bias(x, nomralized_shape, weight, bias, eps)
+
+
# %%
def main() -> None:
    """
    Main execution function for the layer normalization example.
    - Generates random input, weight, and bias tensors.
    - Runs the Helion layer normalization kernel and compares its output to PyTorch's
      built-in layer_norm function using the run_example utility.
+    - Tests both with bias and without bias (no-bias mode).
    - Prints comparison results and checks for correctness within specified tolerances.
    """
    batch_size = 32
@@ -70,15 +123,42 @@ def main() -> None:
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)
    eps = 1e-4
+
+    # Test with bias
+    print("Testing layer_norm WITH bias:")
    run_example(
-        layer_norm_fwd,
+        layer_norm_fwd_with_bias,
        torch.nn.functional.layer_norm,
        (x, [dim], weight, bias, eps),
-        kernel_name="helion",
+        kernel_name="helion_with_bias",
+        baseline_name="torch",
+        rtol=1e-3,
+        atol=1e-3,
+    )
+
+    # Test without bias (no-bias mode)
+    print("\nTesting layer_norm WITHOUT bias (no-bias mode):")
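+    # The torch baseline is wrapped in a lambda so bias is passed as None, matching the no-bias kernel signature.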
+    run_example(
+        layer_norm_fwd_no_bias,
+        lambda x, shape, w, e: torch.nn.functional.layer_norm(x, shape, w, None, e),
+        (x, [dim], weight, eps),
+        kernel_name="helion_no_bias",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )
+
+    # Test wrapper function with bias
+    print("\nTesting wrapper function WITH bias:")
+    result_with_bias = layer_norm_fwd_tritonbench(x, [dim], weight, bias, eps)
+    expected_with_bias = torch.nn.functional.layer_norm(x, [dim], weight, bias, eps)
+    print(f"Wrapper with bias matches torch: {torch.allclose(result_with_bias, expected_with_bias, rtol=1e-3, atol=1e-3)}")
+
+    # Test wrapper function without bias
+    print("\nTesting wrapper function WITHOUT bias:")
+    result_no_bias = layer_norm_fwd_tritonbench(x, [dim], weight, None, eps)
+    expected_no_bias = torch.nn.functional.layer_norm(x, [dim], weight, None, eps)
+    print(f"Wrapper without bias matches torch: {torch.allclose(result_no_bias, expected_no_bias, rtol=1e-3, atol=1e-3)}")


# %%