Skip to content

Commit b12b6f4

Browse files
committed
Removing arange and replacing with range
1 parent 1727f2d commit b12b6f4

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import torch
66
from torch._decomp import register_decomposition
77
from torch._ops import OpOverload
8-
from torch_tensorrt.dynamo.conversion.converter_utils import (
9-
get_positive_dim,
10-
)
8+
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
119

1210
from ._decomposition_groups import (
1311
ENABLED_TORCH_DECOMPOSITIONS,
@@ -187,21 +185,22 @@ def slice_scatter_decomposition(
187185
if step_dim > src_dim[dim]:
188186
end_dim = src_dim[dim]
189187
else:
190-
indices = torch.Tensor(torch.arange(0, step_dim))
188+
indices = torch.arange(0, step_dim)
191189
indices = indices.to(torch.int32)
192190
src = torch.index_select(src, dim, indices)
193191

194192
if start == 0 and end == dim_size and step == 0:
195193
return input_tensor
196-
index_tensor = torch.arange(start, end_dim, step)
197194

198195
unbind_input_tensors = torch.unbind(input_tensor, dim)
199196
unbind_input_tensors_list = list(unbind_input_tensors)
200197
unbind_source_tensors = torch.unbind(src, dim)
201198
unbind_source_tensors_list = list(unbind_source_tensors)
202199

203-
for i, index in enumerate(index_tensor):
200+
i = 0
201+
for index in range(start, end_dim, step):
204202
unbind_input_tensors_list[index] = unbind_source_tensors_list[i]
203+
i = i + 1
205204

206205
return torch.stack(unbind_input_tensors_list, dim)
207206

0 commit comments

Comments
 (0)