Skip to content

Commit cbb7ae1

Browse files
committed
changing decomposition pattern
1 parent 50dbeab commit cbb7ae1

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from typing import Any, Callable, Dict, List, Optional
33

4+
import numpy as np
45
import torch
56
from torch._decomp import register_decomposition
67
from torch._ops import OpOverload
@@ -174,7 +175,7 @@ def slice_scatter_decomposition(
174175
step: int,
175176
):
176177
dim_size = input_tensor.shape[dim]
177-
input_tensor_shape = input_tensor.shape
178+
# input_tensor_shape = input_tensor.shape
178179
if start is not None and start < 0:
179180
start = start + dim_size
180181
if end is not None and end < 0:
@@ -185,20 +186,22 @@ def slice_scatter_decomposition(
185186
end = dim_size
186187

187188
src_dim = list(src_tensor.shape())
188-
src_dim[dim] = torch.floor_divide(end - start, step)
189-
src = torch.expand(src, src_dim)
189+
step_dim = torch.floor_divide(end - start, step)
190+
# src = torch.expand(src, src_dim)
191+
end_dim = end
192+
if step_dim > src_dim[dim]:
193+
end_dim = src_dim[dim]
194+
else:
195+
src_tensor = src_tensor[0:step_dim]
190196

191-
if (start == 0 and end == dim_size and step == 0):
197+
if start == 0 and end == dim_size and step == 0:
192198
return input_tensor
193-
mask = []
194-
if start != 0:
195-
mask.append(torch.ge(input_tensor_shape, start))
196-
if end != dim_size:
197-
mask.append(torch.ge(input_tensor_shape, end))
198-
if step != 1:
199-
mask.append(torch.eq(src_dim, 0))
200-
src_val = torch.masked(mask, src_dim, 0)
201-
return torch.where(mask, src_val,input_tensor)
199+
index_tensor = np.arange[start, end_dim, step]
200+
201+
unbind_tensors = torch.unbind(input_tensor, dim)
202+
unbind_tensors[index_tensor] = src_tensor
203+
return torch.cat(unbind_tensors, dim)
204+
202205

203206
def get_decompositions(
204207
enable_experimental_decompositions: bool = False,

0 commit comments

Comments
 (0)