
Commit 435b56f

[example] fused_linear_jsd
1 parent 01c831e commit 435b56f

5 files changed (+249, -10 lines)

benchmarks/run.py

Lines changed: 7 additions & 0 deletions

@@ -109,6 +109,11 @@
             ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
         ],
     ),
+    "fused_linear_jsd": (
+        "tritonbench.operators.fused_linear_jsd.operator",
+        "examples.fused_linear_jsd",
+        "fused_linear_jsd_fwd_tritonbench",
+    ),
 }


@@ -405,6 +410,8 @@ def helion_method(
         attr.settings.force_autotune = True
         attr.settings.static_shape = True  # pyright: ignore[reportAttributeAccessIssue]

+    kfunc._self = self  # pyright: ignore[reportFunctionMemberAccess]
+
     def _inner() -> Callable[..., Any] | object:
         # BENCHMARK HOT PATH, do not add any new logic here
         result = kfunc(*args, **kwargs)
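
Note: each registry entry maps a tritonbench operator module to the Helion example module and its entry-point function name, and `kfunc._self = self` stashes the running tritonbench operator on the wrapped kernel so the example can read the baseline's hyperparameters at call time. A minimal sketch of that hand-off pattern; the `SimpleNamespace` harness below is a hypothetical stand-in for the real operator object:

    from types import SimpleNamespace

    def entry_point(x: float) -> float:
        # Read hyperparameters off the operator object attached by the runner,
        # mirroring how fused_linear_jsd_fwd_tritonbench reads `_self.baseline_op`.
        op = entry_point._self.baseline_op
        return x * op.temperature

    # What `kfunc._self = self` does inside helion_method in benchmarks/run.py.
    entry_point._self = SimpleNamespace(baseline_op=SimpleNamespace(temperature=2.0))
    assert entry_point(3.0) == 6.0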

examples/fused_linear_jsd.py

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@ (new file)

from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel()
def fused_linear_jsd_fwd(
    beta: float,
    ignore_index: int,  # accepted to match the baseline signature; unused below
    temperature: float,
    student_weight: torch.Tensor,
    teacher_weight: torch.Tensor,
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
) -> torch.Tensor:
    # Fuse the two linear projections with the JSD loss computation.
    student_logits = student_input @ student_weight.T
    teacher_logits = teacher_input @ teacher_weight.T
    loss = student_logits.new_empty(student_input.shape[0], dtype=torch.float)
    # Tile over the batch dimension; each tile reduces over the vocab dimension.
    for batch in hl.tile(student_logits.shape[0]):
        student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1)
        teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1)
        student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
        teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
        # m is the beta-weighted mixture of the student and teacher distributions.
        m = torch.exp(student_prob) + beta * (
            torch.exp(teacher_prob) - torch.exp(student_prob)
        )
        teacher_div = torch.nn.functional.kl_div(
            torch.log(m), teacher_prob, reduction="none", log_target=True
        ).sum(dim=-1)
        student_div = torch.nn.functional.kl_div(
            torch.log(m), student_prob, reduction="none", log_target=True
        ).sum(dim=-1)
        batch_loss = student_div + beta * (teacher_div - student_div)
        loss[batch] = batch_loss
    return (loss / student_logits.shape[0]).sum()


# %%
# Benchmark Entry Point Function
# ------------------------------
def fused_linear_jsd_fwd_tritonbench(
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
    label: torch.Tensor | None = None,
) -> torch.Tensor:
    assert label is None
    # Hyperparameters come from the tritonbench baseline operator attached
    # by benchmarks/run.py via `kfunc._self = self`.
    baseline_op = fused_linear_jsd_fwd_tritonbench._self.baseline_op  # pyright: ignore[reportFunctionMemberAccess]
    beta = baseline_op.jsd.beta
    ignore_index = baseline_op.jsd.ignore_index
    temperature = baseline_op.temperature
    student_weight = baseline_op.student_lin.weight
    teacher_weight = baseline_op.teacher_lin.weight
    return fused_linear_jsd_fwd(
        beta,
        ignore_index,
        temperature,
        student_weight,
        teacher_weight,
        student_input,
        teacher_input,
    )


# %%
# Reference Implementation
# ------------------------
def fused_linear_jsd_pytorch(
    beta: float,
    ignore_index: int,
    temperature: float,
    student_weight: torch.Tensor,
    teacher_weight: torch.Tensor,
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
) -> torch.Tensor:
    student_logits = student_input @ student_weight.T
    teacher_logits = teacher_input @ teacher_weight.T
    student_prob = torch.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = torch.log_softmax(teacher_logits / temperature, dim=-1)
    student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
    teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
    m = torch.exp(student_prob) + beta * (
        torch.exp(teacher_prob) - torch.exp(student_prob)
    )
    teacher_div = torch.nn.functional.kl_div(
        torch.log(m), teacher_prob, reduction="none", log_target=True
    ).sum(dim=-1)
    student_div = torch.nn.functional.kl_div(
        torch.log(m), student_prob, reduction="none", log_target=True
    ).sum(dim=-1)
    loss = student_div + beta * (teacher_div - student_div)
    return (loss / student_logits.shape[0]).sum()


# %%
# Verification Function
# ---------------------
def check(m: int, n: int, k: int) -> None:
    student_input = torch.rand([m, n], device="cuda", dtype=torch.float)
    teacher_input = torch.rand([m, n], device="cuda", dtype=torch.float)
    student_weight = torch.rand([k, n], device="cuda", dtype=torch.float)
    teacher_weight = torch.rand([k, n], device="cuda", dtype=torch.float)
    run_example(
        fused_linear_jsd_fwd,
        fused_linear_jsd_pytorch,
        (0.5, 1, 1.0, student_weight, teacher_weight, student_input, teacher_input),
    )


# %%
# Main Function
# -------------
def main() -> None:
    check(1024, 4096, 128256)


if __name__ == "__main__":
    main()
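
Note: the per-row quantity both implementations compute is the beta-skewed (generalized) Jensen-Shannon divergence. The identity below is standard background, stated here for clarity rather than taken from the commit. Writing P for the student distribution, Q for the teacher distribution, and M for their mixture:

    M = (1 - \beta)\,P + \beta\,Q
    \mathrm{JSD}_\beta(P \,\|\, Q) = (1 - \beta)\,\mathrm{KL}(P \,\|\, M) + \beta\,\mathrm{KL}(Q \,\|\, M)

In the code, `m = torch.exp(student_prob) + beta * (torch.exp(teacher_prob) - torch.exp(student_prob))` is exactly M, and `student_div + beta * (teacher_div - student_div)` expands to the right-hand side above. Note that `torch.nn.functional.kl_div(torch.log(m), p, log_target=True)` computes KL(p || m): kl_div's first argument is the log of the reference distribution and its second argument is the target.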

helion/autotuner/base_search.py

Lines changed: 2 additions & 1 deletion

@@ -52,6 +52,7 @@
             "misaligned address",  # CUDA Error
             "PassManager::run failed",  # Triton Error
             "illegal memory access",  # CUDA Error
+            "exceeds triton maximum tensor numel",  # Triton Error
         ],
     )
 )

@@ -147,7 +148,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         except PTXASError:
             self.log.warning(f"PTXASError compiling config: {config}")
         except Exception as e:
-            if not _expected_errors_regexp.search(str(e)):
+            if not _expected_errors_regexp.search(str(e) + str(e.__cause__)):
                 raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
             self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
             return inf
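
Note: the widened search string matters because benchmark failures are re-raised with `raise ... from e`, so the recognizable Triton message can live on the chained `__cause__` rather than on the outer exception. A minimal, self-contained illustration of the pattern; names here are illustrative, not from the codebase:

    import re

    _expected = re.compile("exceeds triton maximum tensor numel")

    try:
        try:
            # The underlying error carries the message we want to match.
            raise RuntimeError("exceeds triton maximum tensor numel (1048576)")
        except RuntimeError as inner:
            # Re-raising wraps it; the outer message no longer matches.
            raise ValueError("benchmarking failed") from inner
    except Exception as e:
        assert _expected.search(str(e)) is None             # outer message alone misses
        assert _expected.search(str(e) + str(e.__cause__))  # combined text matches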

test/test_examples.expected

Lines changed: 71 additions & 0 deletions

@@ -840,6 +840,77 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestExamples.test_fused_linear_jsd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fused_linear_jsd_fwd(student_logits, teacher_logits, loss, student_input_size_0, student_weight_size_0, loss_stride_0, student_logits_stride_0, student_logits_stride_1, teacher_logits_stride_0, teacher_logits_stride_1, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < student_input_size_0
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < student_weight_size_0
+    load = tl.load(student_logits + (indices_0[:, None] * student_logits_stride_0 + indices_1[None, :] * student_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = load / temperature
+    _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], v_0, float('-inf'))
+    amax = tl.reshape(tl.max(_mask_to, 1), [_BLOCK_SIZE_0, 1])
+    v_1 = v_0 - amax
+    v_2 = tl_math.exp(v_1)
+    _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_2, 0)
+    sum_1 = tl.reshape(tl.sum(_mask_to_1, 1), [_BLOCK_SIZE_0, 1])
+    v_3 = tl_math.log(sum_1)
+    v_4 = v_1 - v_3
+    load_1 = tl.load(teacher_logits + (indices_0[:, None] * teacher_logits_stride_0 + indices_1[None, :] * teacher_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_5 = load_1 / temperature
+    _mask_to_2 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, float('-inf'))
+    amax_1 = tl.reshape(tl.max(_mask_to_2, 1), [_BLOCK_SIZE_0, 1])
+    v_6 = v_5 - amax_1
+    v_7 = tl_math.exp(v_6)
+    _mask_to_3 = tl.where(mask_0[:, None] & mask_1[None, :], v_7, 0)
+    sum_2 = tl.reshape(tl.sum(_mask_to_3, 1), [_BLOCK_SIZE_0, 1])
+    v_8 = tl_math.log(sum_2)
+    v_9 = v_6 - v_8
+    student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    v_10 = tl_math.exp(student_prob_1)
+    v_11 = tl_math.exp(teacher_prob_1)
+    v_12 = tl_math.exp(student_prob_1)
+    v_13 = v_11 - v_12
+    v_14 = v_13 * beta
+    v_15 = v_10 + v_14
+    v_16 = tl_math.log(v_15)
+    v_17 = teacher_prob_1 - v_16
+    v_18 = tl_math.exp(teacher_prob_1)
+    v_19 = v_18 * v_17
+    _mask_to_4 = tl.where(mask_0[:, None] & mask_1[None, :], v_19, 0)
+    teacher_div = tl.sum(_mask_to_4, 1)
+    v_20 = tl_math.log(v_15)
+    v_21 = student_prob_1 - v_20
+    v_22 = tl_math.exp(student_prob_1)
+    v_23 = v_22 * v_21
+    _mask_to_5 = tl.where(mask_0[:, None] & mask_1[None, :], v_23, 0)
+    student_div = tl.sum(_mask_to_5, 1)
+    v_24 = teacher_div - student_div
+    v_25 = v_24 * beta
+    v_26 = student_div + v_25
+    tl.store(loss + indices_0 * loss_stride_0, v_26, mask_0)
+
+def fused_linear_jsd_fwd(beta: float, ignore_index: int, temperature: float, student_weight: torch.Tensor, teacher_weight: torch.Tensor, student_input: torch.Tensor, teacher_input: torch.Tensor, *, _launcher=_default_launcher):
+    student_logits = student_input @ student_weight.T
+    teacher_logits = teacher_input @ teacher_weight.T
+    loss = student_logits.new_empty(student_input.shape[0], dtype=torch.float)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(student_weight.size(0))
+    _launcher(_helion_fused_linear_jsd_fwd, (triton.cdiv(student_input.size(0), _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, student_input.size(0), student_weight.size(0), loss.stride(0), student_logits.stride(0), student_logits.stride(1), teacher_logits.stride(0), teacher_logits.stride(1), temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return (loss / student_logits.shape[0]).sum()
+
 --- assertExpectedJournal(TestExamples.test_jagged_dense_add)
 from __future__ import annotations
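
Note: the generated kernel lowers `torch.log_softmax` to the standard max-shifted form; `v_1` through `v_4` implement, per row x,

    \log\mathrm{softmax}(x)_i = (x_i - \max_j x_j) - \log \sum_j e^{x_j - \max_k x_k}

Out-of-range lanes are masked to -inf before the row max and to 0 before the row sum, so the padding introduced by `_RDIM_SIZE_1 = triton.next_power_of_2(...)` never affects the reductions.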

test/test_examples.py

Lines changed: 47 additions & 9 deletions

@@ -2,17 +2,20 @@

 import unittest

-from packaging import version
 import torch

-from helion._testing import DEVICE
-from helion._testing import EXAMPLES_DIR
-from helion._testing import RefEagerTestBase
-from helion._testing import TestCase
-from helion._testing import check_example
-from helion._testing import import_path
-from helion._testing import skipIfRefEager
-from helion._testing import skipIfRocm
+from helion._testing import (
+    check_example,
+    DEVICE,
+    EXAMPLES_DIR,
+    import_path,
+    RefEagerTestBase,
+    skipIfRefEager,
+    skipIfRocm,
+    TestCase,
+)
+
+from packaging import version

 torch.backends.cuda.matmul.fp32_precision = "tf32"
 torch.backends.cudnn.conv.fp32_precision = "tf32"

@@ -629,6 +632,41 @@ def test_layernorm(self):
             )
         )

+    def test_fused_linear_jsd(self):
+        beta = 0.5
+        ignore_index = 1
+        temperature = 1.0
+        m, n, k = 64, 128, 256
+
+        student_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
+        teacher_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
+        student_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
+        teacher_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
+
+        args = (
+            beta,
+            ignore_index,
+            temperature,
+            student_weight,
+            teacher_weight,
+            student_input,
+            teacher_input,
+        )
+
+        # Import and use the reference implementation
+        mod = import_path(EXAMPLES_DIR / "fused_linear_jsd.py")
+        expected = mod.fused_linear_jsd_pytorch(*args)
+
+        self.assertExpectedJournal(
+            check_example(
+                "fused_linear_jsd",
+                args,
+                expected,
+                fn_name="fused_linear_jsd_fwd",
+                block_sizes=[32],
+            )
+        )
+

 if __name__ == "__main__":
     unittest.main()
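
To exercise the new paths locally, commands along these lines should work; treat them as illustrative, since exact invocations depend on your checkout and environment:

    python examples/fused_linear_jsd.py              # runs check(1024, 4096, 128256) on CUDA
    python -m pytest test/test_examples.py -k test_fused_linear_jsd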
