
Commit e5a8e0e (1 parent: dddbf4d)

perf: Optimize op::transpose
2 files changed: +73 -13 lines changed

src/Operation/OpTransformCpu.h
Lines changed: 6 additions & 0 deletions

@@ -69,6 +69,12 @@ Tensor transposeOpCpuImpl(const Tensor& self, int64_t dim0, int64_t dim1) {
     return self.clone();
   }
 
+  if ((self.size(dim0) == 1 || self.size(dim1) == 1) && std::abs(dim0 - dim1) == 1) {
+    SizeVector retShape(self.shape());
+    std::swap(retShape[dim0], retShape[dim1]);
+    return op::view(self, retShape);
+  }
+
   if (self.dim() == 2) {
     return transpose2dOpCpuImpl<T>(self);
   }
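Why this fast path is safe: for a contiguous row-major tensor, swapping a size-1 dimension with an adjacent one does not reorder elements in memory, so the transpose is a pure metadata change and op::view avoids a copy entirely. A minimal standalone sketch of the index arithmetic (illustrative only, not code from this repo):

// Row-major linear offset of (i, j, k) in shape (A, B, C) is i*B*C + j*C + k.
// With B == 1, shape (A, 1, C) and its transpose (1, A, C) enumerate the
// same offsets, so no data movement is needed.
#include <cassert>
#include <cstdint>

int64_t offset3d(int64_t i, int64_t j, int64_t k, int64_t B, int64_t C) {
  return i * B * C + j * C + k;  // contiguous row-major strides
}

int main() {
  const int64_t A = 4, C = 5;
  for (int64_t i = 0; i < A; i++) {
    for (int64_t k = 0; k < C; k++) {
      // (i, 0, k) in shape (A, 1, C) maps to (0, i, k) in shape (1, A, C)
      assert(offset3d(i, 0, k, /*B=*/1, C) == offset3d(0, i, k, /*B=*/A, C));
    }
  }
  return 0;
}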

src/Operation/OpTransformCuda.cuh
Lines changed: 67 additions & 13 deletions

@@ -407,29 +407,83 @@ Tensor transposeOpCudaImpl(const Tensor& self, int64_t dim0, int64_t dim1) {
     return self.clone();
   }
 
-  if (self.dim() == 2) {
-    return transpose2dOpCudaImpl<T>(self);
+  if (dim0 > dim1) {
+    std::swap(dim0, dim1);
   }
 
   SizeVector retShape(self.shape());
   std::swap(retShape[dim0], retShape[dim1]);
-  auto ret = Tensor::empty(retShape, self.options().noGrad());
 
-  const auto* selfPtr = self.dataPtr<T>();
-  auto* retPtr = ret.dataPtr<T>();
+  if ((self.size(dim0) == 1 || self.size(dim1) == 1) && dim1 - dim0 == 1) {
+    return op::view(self, retShape);
+  }
+
+  if (self.dim() == 2) {
+    return transpose2dOpCudaImpl<T>(self);
+  }
+
+  SizeVector mergedShape;
+  SizeVector mergedOutShape;
+
+  int64_t preSize = 1;
+  for (int64_t i = 0; i < dim0; i++) {
+    preSize *= self.size(i);
+  }
+  if (preSize > 1) {
+    mergedShape.pushBack(preSize);
+    mergedOutShape.pushBack(preSize);
+  }
+
+  mergedShape.pushBack(self.size(dim0));
+  mergedOutShape.pushBack(self.size(dim1));
+
+  int64_t midSize = 1;
+  for (int64_t i = dim0 + 1; i < dim1; i++) {
+    midSize *= self.size(i);
+  }
+  if (midSize > 1) {
+    mergedShape.pushBack(midSize);
+    mergedOutShape.pushBack(midSize);
+  }
+
+  mergedShape.pushBack(self.size(dim1));
+  mergedOutShape.pushBack(self.size(dim0));
+
+  int64_t postSize = 1;
+  for (int64_t i = dim1 + 1; i < self.dim(); i++) {
+    postSize *= self.size(i);
+  }
+  if (postSize > 1) {
+    mergedShape.pushBack(postSize);
+    mergedOutShape.pushBack(postSize);
+  }
+
+  Tensor mergedInput = op::reshape(self, mergedShape);
+  Tensor mergedOutput = Tensor::empty(mergedOutShape, self.options().noGrad());
+
+  int64_t newDim0 = 0, newDim1 = 0;
+  int pos = 0;
+  if (preSize > 1) {
+    pos++;
+  }
+  newDim0 = pos++;
+  if (midSize > 1) {
+    pos++;
+  }
+  newDim1 = pos;
 
   DimArray<int64_t> inStrides{};
   DimArray<int64_t> outStrides{};
-
-  for (auto i = 0; i < self.dim(); i++) {
-    inStrides.data[i] = self.stride(i);
-    outStrides.data[i] = ret.stride(i);
+  for (auto i = 0; i < mergedInput.dim(); i++) {
+    inStrides.data[i] = mergedInput.stride(i);
+    outStrides.data[i] = mergedOutput.stride(i);
   }
 
-  auto params = cuda::getKernelLaunchParams(self.device().index, self.numel());
-  CUDA_LAUNCH_KERNEL(kTransposeND<T>, params, retPtr, selfPtr, self.dim(), dim0, dim1, self.numel(), outStrides,
-                     inStrides);
-  return ret;
+  auto params = cuda::getKernelLaunchParams(self.device().index, mergedOutput.numel());
+  CUDA_LAUNCH_KERNEL(kTransposeND<T>, params, mergedOutput.dataPtr<T>(), mergedInput.dataPtr<T>(), mergedInput.dim(),
+                     newDim0, newDim1, mergedOutput.numel(), outStrides, inStrides);
+
+  return op::reshape(mergedOutput, retShape);
 }
 
 template <typename T>
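The CUDA path now collapses an N-D transpose into at most five dimensions, [pre][dim0][mid][dim1][post], by flattening the contiguous runs before, between, and after the two swapped axes; only dim0 and dim1 actually move, so each flattened run can be treated as a single axis. A standalone sketch of that merging step (mergeForTranspose and its types are illustrative names, not the repo's API):

// Collapse an N-D transpose of (dim0, dim1), dim0 < dim1, into <= 5 dims.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Merged {
  std::vector<int64_t> shape;  // merged input shape
  int64_t newDim0, newDim1;    // positions of the swapped axes after merging
};

Merged mergeForTranspose(const std::vector<int64_t>& shape, int64_t dim0, int64_t dim1) {
  auto prod = [&](int64_t lo, int64_t hi) {
    int64_t p = 1;
    for (int64_t i = lo; i < hi; i++) p *= shape[i];
    return p;
  };
  const int64_t pre = prod(0, dim0);
  const int64_t mid = prod(dim0 + 1, dim1);
  const int64_t post = prod(dim1 + 1, static_cast<int64_t>(shape.size()));

  Merged m;
  if (pre > 1) m.shape.push_back(pre);          // everything before dim0
  m.newDim0 = static_cast<int64_t>(m.shape.size());
  m.shape.push_back(shape[dim0]);
  if (mid > 1) m.shape.push_back(mid);          // everything between dim0 and dim1
  m.newDim1 = static_cast<int64_t>(m.shape.size());
  m.shape.push_back(shape[dim1]);
  if (post > 1) m.shape.push_back(post);        // everything after dim1
  return m;
}

int main() {
  // Shape (2, 2, 3, 4, 5, 5, 6), swapping dims 2 and 5:
  // pre = 2*2 = 4, mid = 4*5 = 20, post = 6,
  // so the merged shape is (4, 3, 20, 5, 6) with newDim0 = 1, newDim1 = 3.
  std::vector<int64_t> shape{2, 2, 3, 4, 5, 5, 6};
  Merged m = mergeForTranspose(shape, 2, 5);
  for (int64_t s : m.shape) std::printf("%lld ", static_cast<long long>(s));
  std::printf("| newDim0=%lld newDim1=%lld\n", static_cast<long long>(m.newDim0),
              static_cast<long long>(m.newDim1));
  return 0;
}

The merge is a metadata-only reshape on contiguous input, and it caps the per-thread index decode in the kernel at five divisions/modulos instead of one per original dimension.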
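The body of kTransposeND is not part of this diff; the following generic stride-walk kernel only illustrates what the inStrides/outStrides arrays are used for, and is an assumption about the kernel's structure, not its actual source:

// Illustrative ND transpose kernel: each thread decodes its output
// coordinate from outStrides, swaps the coordinates at dim0/dim1, and
// re-encodes an input offset via inStrides.
__global__ void transposeNdSketch(float* out, const float* in, int ndim,
                                  int dim0, int dim1, int64_t numel,
                                  const int64_t* outStrides,
                                  const int64_t* inStrides) {
  int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (idx >= numel) return;
  int64_t rem = idx, inOffset = 0;
  for (int d = 0; d < ndim; d++) {
    const int64_t coord = rem / outStrides[d];  // output coordinate along dim d
    rem %= outStrides[d];
    const int srcDim = (d == dim0) ? dim1 : (d == dim1) ? dim0 : d;
    inOffset += coord * inStrides[srcDim];
  }
  out[idx] = in[inOffset];
}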
