Skip to content

Commit beaab8c

Browse files
committed
[Executorch] make slice_copy parallel
Pull Request resolved: #15830. When doing large prefills in LLMs, slice_copy takes about 5-10% of the time, mainly because of the slicing in the RoPE implementation. Differential Revision: [D85532081](https://our.internmc.facebook.com/intern/diff/D85532081/) ghstack-source-id: 324784683
1 parent 0f529ef commit beaab8c

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

kernels/portable/cpu/util/slice_util.cpp

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <c10/util/irange.h>
1010
#include <executorch/kernels/portable/cpu/util/slice_util.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/runtime/kernel/thread_parallel_interface.h>
1213
#include <cstring>
1314

1415
namespace torch {
@@ -202,12 +203,44 @@ void compute_slice(
202203
InvalidArgument,
203204
/* void */,
204205
"out.nbytes() is smaller than the expected slice size.");
205-
for (const auto i : c10::irange(leading_dims)) {
206-
const char* src = input_data + (i * dim_length + start) * length_per_step;
207-
for ([[maybe_unused]] const auto j : c10::irange(length)) {
208-
memcpy(dest, src, length_per_step);
209-
src += step * length_per_step;
210-
dest += length_per_step;
206+
// Thresholds for enabling multithreading:
207+
// - Minimum number of leading dimensions: 8
208+
// - Minimum total elements to copy: 32768 (GRAIN_SIZE)
209+
constexpr int64_t MIN_LEADING_DIMS_FOR_MT = 8;
210+
constexpr int64_t MIN_ELEMENTS_FOR_MT =
211+
executorch::extension::internal::GRAIN_SIZE;
212+
213+
const int64_t total_elements = leading_dims * length * trailing_dims;
214+
const bool use_multithreading = leading_dims >= MIN_LEADING_DIMS_FOR_MT &&
215+
total_elements >= MIN_ELEMENTS_FOR_MT;
216+
217+
if (use_multithreading) {
218+
// Use parallel_for to distribute work across leading dimensions
219+
// Grain size is the leading-dimension threshold itself (a count of leading dimensions, not of elements)
220+
const int64_t grain_size = MIN_LEADING_DIMS_FOR_MT;
221+
222+
executorch::extension::parallel_for(
223+
0, leading_dims, grain_size, [&](const auto begin, const auto end) {
224+
for (const auto i : c10::irange(begin, end)) {
225+
const char* src =
226+
input_data + (i * dim_length + start) * length_per_step;
227+
char* local_dest = dest + i * length * length_per_step;
228+
for ([[maybe_unused]] const auto j : c10::irange(length)) {
229+
memcpy(local_dest, src, length_per_step);
230+
src += step * length_per_step;
231+
local_dest += length_per_step;
232+
}
233+
}
234+
});
235+
} else {
236+
// Single-threaded path for small workloads
237+
for (const auto i : c10::irange(leading_dims)) {
238+
const char* src = input_data + (i * dim_length + start) * length_per_step;
239+
for ([[maybe_unused]] const auto j : c10::irange(length)) {
240+
memcpy(dest, src, length_per_step);
241+
src += step * length_per_step;
242+
dest += length_per_step;
243+
}
211244
}
212245
}
213246
}

kernels/portable/cpu/util/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def define_common_targets():
292292
exported_headers = ["slice_util.h"],
293293
deps = [
294294
"//executorch/runtime/kernel:kernel_includes",
295+
"//executorch/extension/threadpool:threadpool",
295296
],
296297
visibility = ["//executorch/kernels/portable/cpu/..."],
297298
)

0 commit comments

Comments
 (0)