Commit 01c6b5b

fix: Fix op::tril/triu when dim > 2
1 parent e5a8e0e commit 01c6b5b
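
Note on the change: the previous triangleOp implementations read self.shape(0) and self.shape(1) as the matrix dimensions, so for a tensor with more than two dimensions the mask was computed over the leading axes and only the first shape(0) * shape(1) elements of the freshly allocated output were written. With this commit the last two dimensions are treated as the matrix and all leading dimensions are folded into a batch, so the triangular mask is applied independently to every trailing rows x cols slice. The snippet below is a minimal standalone sketch of that intended behavior on plain row-major arrays; the helper name trilBatchedRef and the choice of float are illustrative only and not part of this project's API.

#include <cstdint>
#include <iostream>
#include <vector>

// Reference semantics for a batched tril on a contiguous row-major
// (batch, rows, cols) buffer: each trailing rows x cols matrix is masked
// independently, keeping elements with j <= i + diagonal.
std::vector<float> trilBatchedRef(const std::vector<float>& src, int64_t batch,
                                  int64_t rows, int64_t cols, int64_t diagonal = 0) {
  std::vector<float> dst(src.size(), 0.0f);
  const int64_t matrixSize = rows * cols;
  for (int64_t b = 0; b < batch; b++) {
    for (int64_t i = 0; i < rows; i++) {
      for (int64_t j = 0; j < cols; j++) {
        if (j <= i + diagonal) {
          const int64_t idx = b * matrixSize + i * cols + j;
          dst[idx] = src[idx];
        }
      }
    }
  }
  return dst;
}

int main() {
  // A (2, 2, 2) tensor of ones: tril keeps the lower triangle of *both* slices.
  std::vector<float> ones(8, 1.0f);
  const auto out = trilBatchedRef(ones, /*batch=*/2, /*rows=*/2, /*cols=*/2);
  for (float v : out) std::cout << v << ' ';  // prints: 1 0 1 1 1 0 1 1
  std::cout << '\n';
  return 0;
}
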

File tree

  src/Operation/OpTransformCpu.h
  src/Operation/OpTransformCuda.cuh

2 files changed: +43 -23

src/Operation/OpTransformCpu.h

Lines changed: 21 additions & 11 deletions
@@ -181,22 +181,32 @@ void indexPutAdvanceOpCpuImpl(Tensor& self, ArrayView<Tensor> indices, const Ten
 template <typename T, bool LOWER>
 Tensor triangleOpCpuImpl(const Tensor& self, int64_t diagonal) {
   auto ret = Tensor::empty(self.shape(), self.options().noGrad());
-  const auto rows = self.shape(0);
-  const auto cols = self.shape(1);
+
+  const auto dims = self.dim();
+  const auto rows = self.shape(dims - 2);
+  const auto cols = self.shape(dims - 1);
+
+  int64_t batch = 1;
+  for (int i = 0; i < dims - 2; i++) {
+    batch *= self.shape(i);
+  }
 
   const T* selfPtr = self.dataPtr<T>();
   T* retPtr = ret.dataPtr<T>();
 
-  int64_t idx = 0;
-  for (auto i = 0; i < rows; i++) {
-    idx = i * cols;
-    for (auto j = 0; j < cols; j++) {
-      if ((LOWER && j <= i + diagonal) || (!LOWER && j >= i + diagonal)) {
-        retPtr[idx] = selfPtr[idx];
-      } else {
-        retPtr[idx] = 0;
+  const int64_t matrixSize = rows * cols;
+  for (int64_t b = 0; b < batch; b++) {
+    const T* selfBatchPtr = selfPtr + b * matrixSize;
+    T* retBatchPtr = retPtr + b * matrixSize;
+
+    for (int64_t i = 0; i < rows; i++) {
+      for (int64_t j = 0; j < cols; j++) {
+        if ((LOWER && j <= i + diagonal) || (!LOWER && j >= i + diagonal)) {
+          retBatchPtr[i * cols + j] = selfBatchPtr[i * cols + j];
+        } else {
+          retBatchPtr[i * cols + j] = static_cast<T>(0);
+        }
       }
-      idx++;
     }
   }
   return ret;
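
The batched loops above index selfPtr and retPtr with flat offsets, which presumes both tensors are contiguous row-major with shape (leading dims..., rows, cols). Under that assumption the offset b * rows * cols + i * cols + j enumerates every element exactly once; a tiny standalone check of that arithmetic (plain C++, not part of the repository):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t batch = 3, rows = 4, cols = 5;
  const int64_t matrixSize = rows * cols;
  // The (b, i, j) nesting used in triangleOpCpuImpl visits flat offsets
  // 0 .. batch * rows * cols - 1 in order, each exactly once.
  int64_t expected = 0;
  for (int64_t b = 0; b < batch; b++)
    for (int64_t i = 0; i < rows; i++)
      for (int64_t j = 0; j < cols; j++)
        assert(b * matrixSize + i * cols + j == expected++);
  return 0;
}
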

src/Operation/OpTransformCuda.cuh

Lines changed: 22 additions & 12 deletions
@@ -190,16 +190,18 @@ __global__ void kIndexPut2D(T* self, const int64_t* indices0, const int64_t* ind
 }
 
 template <typename T, bool LOWER>
-__global__ void kTriangle(T* ret, const T* t, const int64_t rows, const int64_t cols, const int64_t diagonal) {
-  auto i = blockIdx.y * blockDim.y + threadIdx.y;
-  auto j = blockIdx.x * blockDim.x + threadIdx.x;
-
-  if (i < rows && j < cols) {
-    const auto index = i * cols + j;
+__global__ void kTriangle(T* ret, const T* t, const int64_t batch, const int64_t rows, const int64_t cols,
+                          const int64_t diagonal, const int64_t matrixSize) {
+  const int64_t b = blockIdx.z;
+  const int64_t i = blockIdx.y * blockDim.y + threadIdx.y;
+  const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (b < batch && i < rows && j < cols) {
+    const int64_t index = b * matrixSize + i * cols + j;
     if ((LOWER && j <= i + diagonal) || (!LOWER && j >= i + diagonal)) {
       ret[index] = t[index];
     } else {
-      ret[index] = 0;
+      ret[index] = static_cast<T>(0);
     }
   }
 }
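
In the updated kernel each thread handles one (b, i, j) element: blockIdx.z selects the matrix within the batch, the 2D block/grid tiles the trailing rows x cols matrix, and the flat offset is b * matrixSize + i * cols + j. The keep/zero decision itself is unchanged; restated host-side for clarity (LOWER selects tril, !LOWER selects triu; this helper is only an illustration, not code from the repository):

#include <cstdint>

// keepElement<true>  -> tril: keep j <= i + diagonal (lower triangle)
// keepElement<false> -> triu: keep j >= i + diagonal (upper triangle)
template <bool LOWER>
constexpr bool keepElement(int64_t i, int64_t j, int64_t diagonal) {
  return (LOWER && j <= i + diagonal) || (!LOWER && j >= i + diagonal);
}

static_assert(keepElement<true>(1, 0, 0), "tril keeps below the diagonal");
static_assert(!keepElement<true>(0, 1, 0), "tril zeroes above the diagonal");
static_assert(keepElement<false>(0, 1, 0), "triu keeps above the diagonal");
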
@@ -568,17 +570,25 @@ void indexPutAdvanceOpCudaImpl(Tensor& self, ArrayView<Tensor> indices, const Te
 template <typename T, bool LOWER>
 Tensor triangleOpCudaImpl(const Tensor& self, int64_t diagonal) {
   auto ret = Tensor::empty(self.shape(), self.options().noGrad());
-  const auto rows = self.shape(0);
-  const auto cols = self.shape(1);
 
+  const auto dims = self.dim();
+  const auto rows = self.shape(dims - 2);
+  const auto cols = self.shape(dims - 1);
+
+  int64_t batch = 1;
+  for (int i = 0; i < dims - 2; i++) {
+    batch *= self.shape(i);
+  }
+
+  const int64_t matrixSize = rows * cols;
   const T* selfPtr = self.dataPtr<T>();
   T* retPtr = ret.dataPtr<T>();
 
-  dim3 blockSize(CUDA_WARP_SIZE, CUDA_WARP_SIZE);
-  dim3 gridSize((cols + blockSize.x - 1) / blockSize.x, (rows + blockSize.y - 1) / blockSize.y);
+  dim3 blockSize(16, 16);
+  dim3 gridSize((cols + blockSize.x - 1) / blockSize.x, (rows + blockSize.y - 1) / blockSize.y, batch);
 
   const auto stream = cuda::getCurrentCUDAStream(self.device().index).stream;
-  kTriangle<T, LOWER><<<gridSize, blockSize, 0, stream>>>(retPtr, selfPtr, rows, cols, diagonal);
+  kTriangle<T, LOWER><<<gridSize, blockSize, 0, stream>>>(retPtr, selfPtr, batch, rows, cols, diagonal, matrixSize);
   CUDA_KERNEL_CHECK();
   return ret;
 }
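
Launch geometry note: the host now launches a 3D grid with one z-layer per batch matrix, and the block shrinks from CUDA_WARP_SIZE x CUDA_WARP_SIZE to 16 x 16 (256 threads per block). A small standalone illustration of the resulting grid, using made-up shape numbers and plain arithmetic only (not code from the repository):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical input shape (3, 100, 70): a batch of 3 matrices of 100 x 70.
  const int64_t batch = 3, rows = 100, cols = 70;
  const int64_t bx = 16, by = 16;                  // blockSize(16, 16)
  const int64_t gx = (cols + bx - 1) / bx;         // ceil(70 / 16)  = 5
  const int64_t gy = (rows + by - 1) / by;         // ceil(100 / 16) = 7
  const int64_t gz = batch;                        // one z-layer per matrix
  const int64_t threads = gx * bx * gy * by * gz;  // 26880 threads launched
  const int64_t elements = batch * rows * cols;    // 21000 elements written
  std::printf("grid=(%lld,%lld,%lld) threads=%lld elements=%lld\n",
              (long long)gx, (long long)gy, (long long)gz,
              (long long)threads, (long long)elements);
  // Surplus threads fail the i < rows && j < cols guard in kTriangle and
  // return without touching memory.
  return 0;
}
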
