 
 namespace vllm {
 
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__device__ __forceinline__ scalar_t compute(const scalar_t& x,
+                                            const scalar_t& y) {
+  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+}
 // Activation and gating kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x) * y;
+    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
   }
 }
 
@@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 }  // namespace vllm
 
 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                       \
+// Use ACT_FIRST (bool) indicating whether to apply the activation function
+// first.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)            \
   int d = input.size(-1) / 2;                                       \
   int64_t num_tokens = input.numel() / input.size(-1);              \
   dim3 grid(num_tokens);                                            \
@@ -64,27 +74,35 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
   VLLM_DISPATCH_FLOATING_TYPES(                                          \
       input.scalar_type(), "act_and_mul_kernel", [&] {                   \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                         input.data_ptr<scalar_t>(), d);  \
       });
 
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
+}
+
+void mul_and_silu(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
+{
+  // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
+  // applies the silu to the latter half of the input.
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
 }
 
 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
 }
 
 namespace vllm {
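// --- Illustrative reference, not part of the diff above ---
// A minimal host-side sketch of the gating semantics selected by the new
// act_first flag. silu_ref and gate_ref are hypothetical helpers added here
// for illustration only; the half-split row layout mirrors the [..., 2, d]
// input described in the kernel comments.
#include <cmath>
#include <vector>

// Host-side SiLU, x * sigmoid(x), i.e. the activation applied by
// vllm::silu_kernel.
inline float silu_ref(float v) { return v / (1.0f + std::exp(-v)); }

// act_first = true  (silu_and_mul): out[i] = silu(x[i]) * y[i]
// act_first = false (mul_and_silu): out[i] = x[i] * silu(y[i])
// where x is the first half of a row and y is the second half.
std::vector<float> gate_ref(const std::vector<float>& row, bool act_first) {
  const size_t d = row.size() / 2;
  std::vector<float> out(d);
  for (size_t i = 0; i < d; ++i) {
    out[i] = act_first ? silu_ref(row[i]) * row[d + i]
                       : row[i] * silu_ref(row[d + i]);
  }
  return out;
}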