@@ -407,29 +407,83 @@ Tensor transposeOpCudaImpl(const Tensor& self, int64_t dim0, int64_t dim1) {
     return self.clone();
   }
 
-  if (self.dim() == 2) {
-    return transpose2dOpCudaImpl<T>(self);
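+  // Canonicalize the pair so that dim0 < dim1; the merging logic below relies on this ordering.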
+  if (dim0 > dim1) {
+    std::swap(dim0, dim1);
   }
 
   SizeVector retShape(self.shape());
   std::swap(retShape[dim0], retShape[dim1]);
-  auto ret = Tensor::empty(retShape, self.options().noGrad());
 
-  const auto* selfPtr = self.dataPtr<T>();
-  auto* retPtr = ret.dataPtr<T>();
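+  // Fast path: swapping a size-1 dim with its immediate neighbor moves no data, so a view suffices.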
+  if ((self.size(dim0) == 1 || self.size(dim1) == 1) && dim1 - dim0 == 1) {
+    return op::view(self, retShape);
+  }
+
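+  // Fast path: a plain 2-D transpose goes through the dedicated kernel.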
+  if (self.dim() == 2) {
+    return transpose2dOpCudaImpl<T>(self);
+  }
+
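+  // General case: collapse the shape into at most five groups [pre, dim0, mid, dim1, post],
+  // dropping any group whose total extent is 1, so the kernel has fewer dims to index.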
+  SizeVector mergedShape;
+  SizeVector mergedOutShape;
+
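+  // Fuse every dim in front of dim0 into one leading group.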
+  int64_t preSize = 1;
+  for (int64_t i = 0; i < dim0; i++) {
+    preSize *= self.size(i);
+  }
+  if (preSize > 1) {
+    mergedShape.pushBack(preSize);
+    mergedOutShape.pushBack(preSize);
+  }
+
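+  // dim0 keeps its own slot; its extent trades places with dim1's in the output shape.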
+  mergedShape.pushBack(self.size(dim0));
+  mergedOutShape.pushBack(self.size(dim1));
+
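+  // Fuse the dims strictly between dim0 and dim1 into one middle group.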
+  int64_t midSize = 1;
+  for (int64_t i = dim0 + 1; i < dim1; i++) {
+    midSize *= self.size(i);
+  }
+  if (midSize > 1) {
+    mergedShape.pushBack(midSize);
+    mergedOutShape.pushBack(midSize);
+  }
+
+  mergedShape.pushBack(self.size(dim1));
+  mergedOutShape.pushBack(self.size(dim0));
+
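+  // Fuse every dim behind dim1 into one trailing group.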
+  int64_t postSize = 1;
+  for (int64_t i = dim1 + 1; i < self.dim(); i++) {
+    postSize *= self.size(i);
+  }
+  if (postSize > 1) {
+    mergedShape.pushBack(postSize);
+    mergedOutShape.pushBack(postSize);
+  }
+
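+  // Each group is a contiguous run of dims, so for a contiguous input this reshape
+  // only reinterprets metadata.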
+  Tensor mergedInput = op::reshape(self, mergedShape);
+  Tensor mergedOutput = Tensor::empty(mergedOutShape, self.options().noGrad());
+
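+  // Locate where dim0 and dim1 ended up after the optional groups were dropped.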
+  int64_t newDim0 = 0, newDim1 = 0;
+  int pos = 0;
+  if (preSize > 1) {
+    pos++;
+  }
+  newDim0 = pos++;
+  if (midSize > 1) {
+    pos++;
+  }
+  newDim1 = pos;
 
   DimArray<int64_t> inStrides{};
   DimArray<int64_t> outStrides{};
-
-  for (auto i = 0; i < self.dim(); i++) {
-    inStrides.data[i] = self.stride(i);
-    outStrides.data[i] = ret.stride(i);
+  for (auto i = 0; i < mergedInput.dim(); i++) {
+    inStrides.data[i] = mergedInput.stride(i);
+    outStrides.data[i] = mergedOutput.stride(i);
   }
 
-  auto params = cuda::getKernelLaunchParams(self.device().index, self.numel());
-  CUDA_LAUNCH_KERNEL(kTransposeND<T>, params, retPtr, selfPtr, self.dim(), dim0, dim1, self.numel(), outStrides,
-                     inStrides);
-  return ret;
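+  // Launch the ND transpose kernel over the merged tensor (presumably one thread per
+  // output element, given the grid is sized from mergedOutput.numel()).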
+  auto params = cuda::getKernelLaunchParams(self.device().index, mergedOutput.numel());
+  CUDA_LAUNCH_KERNEL(kTransposeND<T>, params, mergedOutput.dataPtr<T>(), mergedInput.dataPtr<T>(), mergedInput.dim(),
+                     newDim0, newDim1, mergedOutput.numel(), outStrides, inStrides);
+
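+  // Expand the merged result back to the caller-visible transposed shape.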
+  return op::reshape(mergedOutput, retShape);
 }
 
 template <typename T>