Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "primitive.hpp"

namespace cldnn {

/// @brief
/// @details
/// @brief Scatter-reduction step of Mixture-of-Experts (MoE).
/// @details For each token, gathers the rows produced by the experts this token was
/// routed to (inputs 1..5 describe that routing) and reduces them into a single
/// per-token output row, weighting each expert's contribution by its routing weight.
struct moe_scatter_reduction : public primitive_base<moe_scatter_reduction> {
CLDNN_DECLARE_PRIMITIVE(moe_scatter_reduction)

// Default constructor: required for deserialization via load().
moe_scatter_reduction() : primitive_base("", {}) {}

/// @brief Constructs moe_scatter_reduction primitive.
///
/// @param id This primitive id.
/// @param data Input data primitive id (expert-processed hidden states).
/// @param experts_per_token sorted topk expert id per token
/// @param expert_weights_per_token sorted topk expert weight per token
/// @param tokens_per_expert tokens per expert
/// @param experts_info_offsets offset of each expert's info from the tokens_per_expert
/// @param tokens_len_per_expert number of tokens routed to each expert
/// @param num_active_experts_per_token top-k value: how many experts contribute to one token
moe_scatter_reduction(const primitive_id& id,
const input_info& data,
const input_info& experts_per_token,
const input_info& expert_weights_per_token,
const input_info& tokens_per_expert,
const input_info& experts_info_offsets,
const input_info& tokens_len_per_expert,
int32_t num_active_experts_per_token = 0)
: primitive_base(id, {data, experts_per_token, expert_weights_per_token, tokens_per_expert,
experts_info_offsets, tokens_len_per_expert}), num_active_experts_per_token(num_active_experts_per_token) {}

// Number of experts reduced per token (top-k of the router); 0 until set.
int32_t num_active_experts_per_token = 0;

// Hash must cover every field that affects compiled-kernel selection.
size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, num_active_experts_per_token);
return seed;
}

bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;

auto rhs_casted = downcast<const moe_scatter_reduction>(rhs);

return num_active_experts_per_token == rhs_casted.num_active_experts_per_token;
}

// Serialization: keep save()/load() symmetric with the fields above.
void save(BinaryOutputBuffer& ob) const override {
primitive_base<moe_scatter_reduction>::save(ob);
ob << num_active_experts_per_token;
}

void load(BinaryInputBuffer& ib) override {
primitive_base<moe_scatter_reduction>::load(ib);
ib >> num_active_experts_per_token;
}
};
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "moe_scatter_reduction.hpp"

#include "../common_utils/dispatch_utils.hpp"
#include "../common_utils/jitter.hpp"
#include "intel_gpu/primitives/moe_scatter_reduction.hpp"
#include "../primitive_ocl_base.hpp"
#include "../utils/kernel_generator.hpp"

namespace ov::intel_gpu::ocl {
namespace {

class MoeScatterReductionRefGenerator : public KernelGenerator {
public:
MoeScatterReductionRefGenerator() : KernelGenerator("moe_scatter_reduction_ref") {}

protected:
static size_t GetBlockSize(const RuntimeParams& params) {
const auto& input = params.get_input_layout(0);
size_t vec_size = 1;
switch (input.data_type) {
case ov::element::i8:
case ov::element::u8:
vec_size = 16;
break;
case ov::element::f16:
vec_size = 8;
break;
case ov::element::f32:
case ov::element::i32:
vec_size = 4;
break;
case ov::element::i64:
vec_size = 2;
break;
default:
vec_size = 1;
break;
}
return vec_size;
}

static auto calc_thread_count(RuntimeParams& params, const int vector_size, const int hidden_size) {
auto max_wgs = params.get_program().get_engine().get_device_info().max_work_group_size;
const uint64_t threads_needed = (hidden_size + vector_size - 1) / vector_size;
size_t local_threads_needed = std::min(threads_needed, max_wgs);
size_t batches_per_thread = 1;
size_t unaligned_elements = 0;

if (threads_needed <= max_wgs) {
batches_per_thread = 1;
unaligned_elements = hidden_size % vector_size;
} else {
batches_per_thread = (threads_needed + max_wgs - 1) / max_wgs;
auto new_block_size = batches_per_thread * vector_size;
unaligned_elements = hidden_size % new_block_size;

local_threads_needed = hidden_size / new_block_size;
auto partialblock = (hidden_size % new_block_size != 0) ? 1 : 0;
local_threads_needed += partialblock;
}

return std::tuple{local_threads_needed, batches_per_thread, unaligned_elements};
}

[[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override {
auto jit = KernelGenerator::get_jit_constants(params);
auto in_l = params.input_layouts[0];
auto hidden_size = extract_channel(ChannelName::FEATURE, in_l);
auto block_size = GetBlockSize(params);
auto [local_threads_count, batches_per_thread, unaligned_elements] = calc_thread_count(
const_cast<RuntimeParams&>(params), block_size, hidden_size);

const auto& desc = params.typed_desc<moe_scatter_reduction>();

jit.make("ACTIVE_EXPERTS", desc->num_active_experts_per_token);
jit.make("HIDDEN_SIZE", hidden_size);
jit.make("VEC_BLK_SIZE", block_size);
jit.make("BATCHES_PER_THREAD", batches_per_thread);
jit.make("UNALIGNED_ELEMENTS", unaligned_elements);

return jit;
}

Arguments get_arguments_desc(const RuntimeParams& params) const override {
Arguments args;
if (params.is_dynamic()) {
args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0});
}

uint32_t num_of_inputs = 6;

for (uint32_t i = 0; i < num_of_inputs; i++) {
args.push_back({ArgumentDescriptor::Types::INPUT, i});
}

args.push_back({ArgumentDescriptor::Types::OUTPUT, 0});

return args;
}

[[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override {
return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) {
auto& wgs = kd.params.workGroups;

if (!params.is_dynamic()) {
auto hidden_size = extract_channel(ChannelName::FEATURE, params.input_layouts[0]);
auto block_size = GetBlockSize(params);
auto [local_threads_count, batches_per_thread, unaligned_elements] = calc_thread_count(
const_cast<RuntimeParams&>(params), block_size, hidden_size);

auto num_tokens = extract_channel(ChannelName::BATCH, params.input_layouts[1]);

wgs.global = {num_tokens * local_threads_count, 1, 1};
wgs.local = { local_threads_count, 1, 1};
}
}};
}
};

// OCL implementation wrapper: owns the single kernel stage produced by
// MoeScatterReductionRefGenerator and plugs it into the PrimitiveImplOCL pipeline.
class MoeScatterReductionRefImpl : public PrimitiveImplOCL {
public:
DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::ocl::MoeScatterReductionRefImpl)

// The only stage of this implementation: the reference scatter-reduction kernel.
Stage::Ptr moe_scatter_reduction = make_stage<MoeScatterReductionRefGenerator>();

// Default constructor: required for deserialization.
MoeScatterReductionRefImpl() : PrimitiveImplOCL(MoeScatterReductionRef::get_type_info_static()) {}
MoeScatterReductionRefImpl(const program_node& node, const RuntimeParams& params) : MoeScatterReductionRefImpl() {
add_stage(moe_scatter_reduction, params);
}

[[nodiscard]] std::unique_ptr<primitive_impl> clone() const override {
return make_deep_copy<MoeScatterReductionRefImpl>(this);
}
};

} // namespace

// Factory entry point called by the implementation registry once
// validate_impl() has accepted the node.
std::unique_ptr<primitive_impl> MoeScatterReductionRef::create_impl(const program_node& node, const RuntimeParams& params) const {
assert(node.is_type<moe_scatter_reduction>());
return std::make_unique<MoeScatterReductionRefImpl>(node, params);
}

} // namespace ov::intel_gpu::ocl

// Register both the primitive and its impl with the binary (de)serialization machinery.
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe_scatter_reduction)
BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::ocl::MoeScatterReductionRefImpl)
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <utility>

#include "program_node.h"
#include "registry/implementation_manager.hpp"

using namespace cldnn; // TODO: Remove once namespaces are aligned

namespace ov::intel_gpu::ocl {

// Implementation manager for the reference OCL moe_scatter_reduction kernel:
// decides whether this implementation can serve a given program node.
struct MoeScatterReductionRef : public ImplementationManager {
    OV_GPU_PRIMITIVE_IMPL("ocl::moe_scatter_reduction")
    explicit MoeScatterReductionRef(shape_types shape_type, ValidateFunc vf = nullptr)
        : ImplementationManager(impl_types::ocl, shape_type, std::move(vf)) {}

    [[nodiscard]] std::unique_ptr<primitive_impl> create_impl(const program_node& node, const RuntimeParams& params) const override;

    // Accepts only 2D bfyx tensors with a static hidden dimension and one of the
    // supported element types on both input 0 and output 0.
    [[nodiscard]] bool validate_impl(const program_node& node) const override {
        static constexpr std::array supported_fmts = {
            format::bfyx,
        };

        static constexpr std::array supported_types = {
            ov::element::f32,
            ov::element::f16,
            ov::element::i32,
            ov::element::i64,
            ov::element::i8,
            ov::element::u8,
        };

        const auto& in0_layout = node.get_input_layout(0);
        const auto& out_layout = node.get_output_layout(0);
        const auto& input_pshape = in0_layout.get_partial_shape();

        // Reject fully dynamic rank up front: Rank::get_length() throws on a
        // dynamic rank, so it must not be called before this check.
        if (input_pshape.rank().is_dynamic()) {
            return false;
        }

        // The kernel assumes [num_tokens, hidden_size] with hidden_size known
        // at JIT-compile time.
        if (input_pshape.rank().get_length() != 2 || input_pshape[1].is_dynamic()) {
            return false;
        }

        if (!one_of(in0_layout.format, supported_fmts) || !one_of(out_layout.format, supported_fmts)) {
            return false;
        }

        if (!one_of(in0_layout.data_type, supported_types) || !one_of(out_layout.data_type, supported_types)) {
            return false;
        }

        return true;
    }
};

} // namespace ov::intel_gpu::ocl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "include/batch_headers/common.cl"
#include "include/fetch_utils.cl"

#define VLOAD CAT(vload, VEC_BLK_SIZE)
#define VSTORE CAT(vstore, VEC_BLK_SIZE)
#define INPUT_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_BLK_SIZE)
#define OUTPUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_BLK_SIZE)

// Scatter-reduction kernel: one work-group per token. Each work-group sums the
// rows produced by the token's ACTIVE_EXPERTS routed experts, scaling every
// expert row by that expert's routing weight, and writes one reduced row of
// HIDDEN_SIZE elements. Each work-item covers BATCHES_PER_THREAD vector blocks
// of VEC_BLK_SIZE elements of the hidden dimension.
KERNEL(moe_scatter_reduction_ref)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* input,
const __global INPUT1_TYPE* experts_per_token,
const __global INPUT2_TYPE* expert_weights,
const __global INPUT3_TYPE* tokens_per_expert,
const __global INPUT4_TYPE* experts_start_offset,
const __global INPUT5_TYPE* tokens_len_per_expert,
__global OUTPUT_TYPE* output
)
{
const uint token_group_id = (uint)get_group_id(0);
const uint threads_index = (uint)get_local_id(0);

// Per-work-item accumulators for the vectorized portion of the row.
OUTPUT_VEC_TYPE output_vec[BATCHES_PER_THREAD];

#if UNALIGNED_ELEMENTS > 0
// Scalar accumulators for the row tail that does not fill a full vector block.
OUTPUT_TYPE output_scalar[UNALIGNED_ELEMENTS];
#endif

// Base offset of this token's output row, plus this work-item's block offset.
uint dest_index = token_group_id * HIDDEN_SIZE;
uint output_pos = dest_index + threads_index * VEC_BLK_SIZE * BATCHES_PER_THREAD;

for (uint i = 0; i < BATCHES_PER_THREAD; i++) {
output_vec[i] = TO_OUTPUT_TYPE(0);
}

#if UNALIGNED_ELEMENTS > 0
for (uint i = 0; i < UNALIGNED_ELEMENTS; i++) {
output_scalar[i] = TO_OUTPUT_TYPE(0);
}
#endif

// Accumulate the weighted contribution of each expert routed to this token.
for (uint i = 0; i < ACTIVE_EXPERTS; i++) {
INPUT1_TYPE expert_id = experts_per_token[token_group_id * ACTIVE_EXPERTS + i];
INPUT2_TYPE expert_weight = expert_weights[token_group_id * ACTIVE_EXPERTS + i];
INPUT5_TYPE token_len = tokens_len_per_expert[expert_id];
INPUT4_TYPE expert_offset = experts_start_offset[expert_id];

// Linear search of the expert's token list to find which of its output
// rows belongs to this token; `input_offset` becomes that row index.
uint input_offset = 0;
for (uint j = 0; j < token_len; j++) {
if (tokens_per_expert[expert_offset + j] == token_group_id) {
input_offset = expert_offset + j;
break;
}
}

for (uint j = 0; j < BATCHES_PER_THREAD; j++) {
const uint input_pos = input_offset * HIDDEN_SIZE + j * VEC_BLK_SIZE + threads_index * VEC_BLK_SIZE * BATCHES_PER_THREAD;

#if UNALIGNED_ELEMENTS > 0
// The last work-item accumulates the unvectorizable tail element-by-element.
// NOTE(review): the tail is handled at j == 0 of the last work-item, while its
// j > 0 iterations still take the vector path below — confirm the host-side
// dispatch (calc_thread_count) guarantees those vector accesses stay inside
// the row when BATCHES_PER_THREAD > 1.
if ((threads_index == get_local_size(0) - 1) && (j == 0)) {
uint input_pos_unaligned = input_pos;
for (uint k = 0; k < UNALIGNED_ELEMENTS; k++) {
output_scalar[k] += input[input_pos_unaligned] * expert_weight;
input_pos_unaligned++;
}
} else {
#endif
INPUT_VEC_TYPE input_data = VLOAD(0, &input[input_pos]);
input_data *= expert_weight;
output_vec[j] += input_data;
#if UNALIGNED_ELEMENTS > 0
}
#endif
}
}

// Write-back: the last work-item stores the scalar tail; all others (or all
// work-items when the row is fully vectorizable) store their vector blocks.
#if UNALIGNED_ELEMENTS > 0
if ((threads_index == get_local_size(0) - 1)) {
uint output_pos_unaligned = output_pos;
for (uint s = 0; s < UNALIGNED_ELEMENTS; s++) {
output[output_pos_unaligned] = output_scalar[s];
output_pos_unaligned++;
}
} else {
#endif
for (uint v = 0; v < BATCHES_PER_THREAD; v++) {
const uint out_pos = output_pos + v * VEC_BLK_SIZE;
VSTORE(output_vec[v], 0, &output[out_pos]);
}
#if UNALIGNED_ELEMENTS > 0
}
#endif
}
Loading
Loading