|
5 | 5 | import torch |
6 | 6 | from torch._decomp import register_decomposition |
7 | 7 | from torch._ops import OpOverload |
8 | | -from torch_tensorrt.dynamo.conversion.converter_utils import ( |
9 | | - get_positive_dim, |
10 | | -) |
| 8 | +from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim |
11 | 9 |
|
12 | 10 | from ._decomposition_groups import ( |
13 | 11 | ENABLED_TORCH_DECOMPOSITIONS, |
@@ -187,21 +185,22 @@ def slice_scatter_decomposition( |
187 | 185 | if step_dim > src_dim[dim]: |
188 | 186 | end_dim = src_dim[dim] |
189 | 187 | else: |
190 | | - indices = torch.Tensor(torch.arange(0, step_dim)) |
| 188 | + indices = torch.arange(0, step_dim) |
191 | 189 | indices = indices.to(torch.int32) |
192 | 190 | src = torch.index_select(src, dim, indices) |
193 | 191 |
|
194 | 192 | if start == 0 and end == dim_size and step == 0: |
195 | 193 | return input_tensor |
196 | | - index_tensor = torch.arange(start, end_dim, step) |
197 | 194 |
|
198 | 195 | unbind_input_tensors = torch.unbind(input_tensor, dim) |
199 | 196 | unbind_input_tensors_list = list(unbind_input_tensors) |
200 | 197 | unbind_source_tensors = torch.unbind(src, dim) |
201 | 198 | unbind_source_tensors_list = list(unbind_source_tensors) |
202 | 199 |
|
203 | | - for i, index in enumerate(index_tensor): |
| 200 | + i = 0 |
| 201 | + for index in range(start, end_dim, step): |
204 | 202 | unbind_input_tensors_list[index] = unbind_source_tensors_list[i] |
| 203 | + i = i + 1 |
205 | 204 |
|
206 | 205 | return torch.stack(unbind_input_tensors_list, dim) |
207 | 206 |
|
|
0 commit comments