|
9 | 9 | #include <c10/util/irange.h> |
10 | 10 | #include <executorch/kernels/portable/cpu/util/slice_util.h> |
11 | 11 | #include <executorch/runtime/kernel/kernel_includes.h> |
| 12 | +#include <executorch/runtime/kernel/thread_parallel_interface.h> |
12 | 13 | #include <cstring> |
13 | 14 |
|
14 | 15 | namespace torch { |
@@ -202,12 +203,44 @@ void compute_slice( |
202 | 203 | InvalidArgument, |
203 | 204 | /* void */, |
204 | 205 | "out.nbytes() is smaller than the expected slice size."); |
205 | | - for (const auto i : c10::irange(leading_dims)) { |
206 | | - const char* src = input_data + (i * dim_length + start) * length_per_step; |
207 | | - for ([[maybe_unused]] const auto j : c10::irange(length)) { |
208 | | - memcpy(dest, src, length_per_step); |
209 | | - src += step * length_per_step; |
210 | | - dest += length_per_step; |
| 206 | + // Thresholds for enabling multithreading: |
| 207 | + // - Minimum number of leading dimensions: 8 |
| 208 | + // - Minimum total elements to copy: 32768 (GRAIN_SIZE) |
| 209 | + constexpr int64_t MIN_LEADING_DIMS_FOR_MT = 8; |
| 210 | + constexpr int64_t MIN_ELEMENTS_FOR_MT = |
| 211 | + executorch::extension::internal::GRAIN_SIZE; |
| 212 | + |
| 213 | + const int64_t total_elements = leading_dims * length * trailing_dims; |
| 214 | + const bool use_multithreading = leading_dims >= MIN_LEADING_DIMS_FOR_MT && |
| 215 | + total_elements >= MIN_ELEMENTS_FOR_MT; |
| 216 | + |
| 217 | + if (use_multithreading) { |
| 218 | + // Use parallel_for to distribute work across leading dimensions |
| 219 | + // Calculate grain size based on number of elements per leading dimension |
| 220 | + const int64_t grain_size = MIN_LEADING_DIMS_FOR_MT; |
| 221 | + |
| 222 | + executorch::extension::parallel_for( |
| 223 | + 0, leading_dims, grain_size, [&](const auto begin, const auto end) { |
| 224 | + for (const auto i : c10::irange(begin, end)) { |
| 225 | + const char* src = |
| 226 | + input_data + (i * dim_length + start) * length_per_step; |
| 227 | + char* local_dest = dest + i * length * length_per_step; |
| 228 | + for ([[maybe_unused]] const auto j : c10::irange(length)) { |
| 229 | + memcpy(local_dest, src, length_per_step); |
| 230 | + src += step * length_per_step; |
| 231 | + local_dest += length_per_step; |
| 232 | + } |
| 233 | + } |
| 234 | + }); |
| 235 | + } else { |
| 236 | + // Single-threaded path for small workloads |
| 237 | + for (const auto i : c10::irange(leading_dims)) { |
| 238 | + const char* src = input_data + (i * dim_length + start) * length_per_step; |
| 239 | + for ([[maybe_unused]] const auto j : c10::irange(length)) { |
| 240 | + memcpy(dest, src, length_per_step); |
| 241 | + src += step * length_per_step; |
| 242 | + dest += length_per_step; |
| 243 | + } |
211 | 244 | } |
212 | 245 | } |
213 | 246 | } |
|
0 commit comments