11import logging
22from typing import Any , Callable , Dict , List , Optional
33
4+ import numpy as np
45import torch
56from torch ._decomp import register_decomposition
67from torch ._ops import OpOverload
@@ -174,7 +175,7 @@ def slice_scatter_decomposition(
174175 step : int ,
175176):
176177 dim_size = input_tensor .shape [dim ]
177- input_tensor_shape = input_tensor .shape
178+ # input_tensor_shape = input_tensor.shape
178179 if start is not None and start < 0 :
179180 start = start + dim_size
180181 if end is not None and end < 0 :
@@ -185,20 +186,22 @@ def slice_scatter_decomposition(
185186 end = dim_size
186187
187188 src_dim = list (src_tensor .shape ())
188- src_dim [dim ] = torch .floor_divide (end - start , step )
189- src = torch .expand (src , src_dim )
189+ step_dim = torch .floor_divide (end - start , step )
190+ # src = torch.expand(src, src_dim)
191+ end_dim = end
192+ if step_dim > src_dim [dim ]:
193+ end_dim = src_dim [dim ]
194+ else :
195+ src_tensor = src_tensor [0 :step_dim ]
190196
191- if ( start == 0 and end == dim_size and step == 0 ) :
197+ if start == 0 and end == dim_size and step == 0 :
192198 return input_tensor
193- mask = []
194- if start != 0 :
195- mask .append (torch .ge (input_tensor_shape , start ))
196- if end != dim_size :
197- mask .append (torch .ge (input_tensor_shape , end ))
198- if step != 1 :
199- mask .append (torch .eq (src_dim , 0 ))
200- src_val = torch .masked (mask , src_dim , 0 )
201- return torch .where (mask , src_val ,input_tensor )
199+ index_tensor = np .arange [start , end_dim , step ]
200+
201+ unbind_tensors = torch .unbind (input_tensor , dim )
202+ unbind_tensors [index_tensor ] = src_tensor
203+ return torch .cat (unbind_tensors , dim )
204+
202205
203206def get_decompositions (
204207 enable_experimental_decompositions : bool = False ,
0 commit comments