@@ -190,16 +190,18 @@ __global__ void kIndexPut2D(T* self, const int64_t* indices0, const int64_t* ind
190190}
191191
/**
 * Writes the lower (LOWER == true) or upper (LOWER == false) triangular part of each
 * rows x cols matrix of `t` into `ret`; every other element of `ret` is set to zero.
 *
 * Element (b, i, j) is addressed as b * matrixSize + i * cols + j in both buffers.
 * It is kept when j - i <= diagonal (LOWER) or j - i >= diagonal (!LOWER).
 *
 * Launch layout: x covers columns, y covers rows, blockIdx.z covers the batch.
 * All three axes are walked with grid-stride loops, so the grid does not have to
 * cover the whole tensor exactly; this matters because gridDim.y / gridDim.z are
 * limited to 65535. With a grid that does cover the tensor, each loop runs at most
 * once per thread.
 *
 * @param ret         output buffer, batch * matrixSize elements
 * @param t           input buffer, same layout as `ret`
 * @param batch       number of matrices
 * @param rows        rows per matrix
 * @param cols        columns per matrix
 * @param diagonal    offset of the boundary diagonal (0 = main diagonal)
 * @param matrixSize  element distance between consecutive matrices
 */
template <typename T, bool LOWER>
__global__ void kTriangle(T* ret, const T* t, const int64_t batch, const int64_t rows, const int64_t cols,
                          const int64_t diagonal, const int64_t matrixSize) {
  // Widen before multiplying: blockIdx/blockDim/gridDim are 32-bit unsigned.
  const int64_t firstCol = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t firstRow = static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
  const int64_t colStride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t rowStride = static_cast<int64_t>(gridDim.y) * blockDim.y;
  const int64_t batchStride = gridDim.z;

  for (int64_t b = blockIdx.z; b < batch; b += batchStride) {
    for (int64_t i = firstRow; i < rows; i += rowStride) {
      for (int64_t j = firstCol; j < cols; j += colStride) {
        const int64_t index = b * matrixSize + i * cols + j;
        // j - i stays within (-rows, cols), so unlike i + diagonal it cannot
        // overflow for extreme diagonal values.
        const bool keep = LOWER ? (j - i <= diagonal) : (j - i >= diagonal);
        // The input is only read for elements that are kept.
        ret[index] = keep ? t[index] : static_cast<T>(0);
      }
    }
  }
}
@@ -568,17 +570,25 @@ void indexPutAdvanceOpCudaImpl(Tensor& self, ArrayView<Tensor> indices, const Te
/**
 * Returns a new tensor shaped like `self` holding the lower (LOWER == true) or upper
 * (LOWER == false) triangular part of each matrix formed by the last two dimensions;
 * all leading dimensions are folded into one batch count.
 *
 * The work is split into launches that respect the 65535 limit CUDA places on
 * gridDim.y and gridDim.z, and empty tensors are returned without launching.
 *
 * NOTE(review): requires self.dim() >= 2 (shape(dims - 2) is read unconditionally);
 * the project's argument-check mechanism is not visible here - confirm callers
 * validate this.
 * NOTE(review): the kernel addresses elements as b * rows * cols + i * cols + j, so
 * `self` is presumably contiguous - confirm.
 *
 * @param self      input tensor
 * @param diagonal  offset of the boundary diagonal (0 = main diagonal)
 * @return          result tensor created with self.options().noGrad()
 */
template <typename T, bool LOWER>
Tensor triangleOpCudaImpl(const Tensor& self, int64_t diagonal) {
  constexpr unsigned int kTile = 16;         // threads per block edge (16 x 16 block)
  constexpr int64_t kMaxGridDimYZ = 65535;   // CUDA limit for gridDim.y and gridDim.z

  auto ret = Tensor::empty(self.shape(), self.options().noGrad());

  const auto dims = self.dim();
  const int64_t rows = self.shape(dims - 2);
  const int64_t cols = self.shape(dims - 1);

  int64_t batch = 1;
  for (int64_t d = 0; d < dims - 2; d++) {
    batch *= self.shape(d);
  }

  // Nothing to compute; a zero-sized grid dimension would be an invalid launch.
  if (batch == 0 || rows == 0 || cols == 0) {
    return ret;
  }

  // Any diagonal outside [-rows, cols] selects the same elements as the nearest
  // bound, so clamping keeps the result identical while making the additions
  // below (and inside the kernel) overflow-free.
  const int64_t diag = diagonal > cols ? cols : (diagonal < -rows ? -rows : diagonal);

  const int64_t matrixSize = rows * cols;
  const T* selfPtr = self.dataPtr<T>();
  T* retPtr = ret.dataPtr<T>();

  const dim3 blockSize(kTile, kTile);
  const auto gridX = static_cast<unsigned int>((cols + kTile - 1) / kTile);
  const int64_t maxRowsPerLaunch = kMaxGridDimYZ * kTile;

  const auto stream = cuda::getCurrentCUDAStream(self.device().index).stream;

  // Each launch handles at most kMaxGridDimYZ matrices and maxRowsPerLaunch rows.
  for (int64_t b0 = 0; b0 < batch; b0 += kMaxGridDimYZ) {
    const int64_t nb = (batch - b0 < kMaxGridDimYZ) ? (batch - b0) : kMaxGridDimYZ;
    for (int64_t r0 = 0; r0 < rows; r0 += maxRowsPerLaunch) {
      const int64_t nr = (rows - r0 < maxRowsPerLaunch) ? (rows - r0) : maxRowsPerLaunch;
      const int64_t offset = b0 * matrixSize + r0 * cols;
      const dim3 gridSize(gridX, static_cast<unsigned int>((nr + kTile - 1) / kTile),
                          static_cast<unsigned int>(nb));
      // Starting at row r0 is equivalent to shifting the diagonal by r0:
      // j <= (i + r0) + diag  <=>  j <= i + (diag + r0).
      kTriangle<T, LOWER><<<gridSize, blockSize, 0, stream>>>(retPtr + offset, selfPtr + offset, nb, nr, cols,
                                                              diag + r0, matrixSize);
      CUDA_KERNEL_CHECK();
    }
  }
  return ret;
}
0 commit comments