2121 set_layer_name ,
2222)
2323from torch_tensorrt .fx .types import Shape , TRTTensor
24+ from torch_tensorrt .fx .utils import Frameworks , unified_dtype_converter
2425
2526_LOGGER : logging .Logger = logging .getLogger (__name__ )
2627
@@ -398,8 +399,8 @@ def scatter_value(
398399 source_ir : Optional [SourceIR ],
399400 name : str ,
400401 input : TRTTensor ,
401- dim : Shape ,
402- index : Shape ,
402+ dim : int ,
403+ index : Union [ TRTTensor , np . ndarray , torch . Tensor ] ,
403404 value : float ,
404405) -> TRTTensor :
405406 if not isinstance (input , TRTTensor ):
@@ -409,26 +410,34 @@ def scatter_value(
409410 )
410411 input_shape = input .shape
411412 index_shape = index .shape
413+ index_shape_list = list (index .shape )
414+ if not (isinstance (index , TRTTensor )):
415+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
412416 if len (input_shape ) != len (index_shape ):
413417 raise RuntimeError (f"The no of dimensions of input and index should be equal" )
414- ranks = len (input_shape )
415- dim = get_positive_dim (cast (int , dim ), ranks )
418+ dim = get_positive_dim (dim , len (input_shape ))
416419 dynamic_shape = has_dynamic_shape (input .shape )
417420 if dynamic_shape :
418421 # Check whether slice target dim is dynamic shape dim
419422 assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
420423
421- input_dims = len (input . shape )
424+ input_dims = len (input_shape )
422425 for i in range (0 , input_dims ):
423- if index [i ] >= input .shape [i ]:
426+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
424427 raise RuntimeError (
425- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
428+ f"cannot have index size greater than the input size along dimension { dim } "
426429 )
427- value_tensor = value * torch .ones (index .shape )
430+
431+ value_tensor = get_trt_tensor (
432+ ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
433+ )
434+ value_tensor = cast_trt_tensor (
435+ ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
436+ )
428437 scatter_layer = ctx .net .add_scatter (
429- input , index , value_tensor , trt .tensorrt . ScatterModekELEMENT
438+ input , index , value_tensor , trt .ScatterMode . ELEMENT
430439 )
431- scatter_layer .set_axis ( dim )
440+ scatter_layer .axis = dim
432441 set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
433442 out = scatter_layer .get_output (0 )
434443 return out
@@ -452,6 +461,8 @@ def scatter_src(
452461 input_shape = input .shape
453462 index_shape = index .shape
454463 src_shape = src .shape
464+ if not (isinstance (index , TRTTensor )):
465+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
455466 if len (input_shape ) != len (index_shape ):
456467 raise RuntimeError (f"The no of dimensions of input and index should be equal" )
457468 if len (index_shape ) != len (src_shape ):
@@ -465,14 +476,23 @@ def scatter_src(
465476 assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
466477
467478 for i in range (0 , input_dims ):
468- if index [i ] >= input .shape [i ]:
479+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
469480 raise RuntimeError (
470- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
481+ f"cannot have index size greater than the input size along dimension { dim } "
471482 )
483+ input_dtype = input .dtype
484+ # required for cases where src is a constant
485+ src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
486+ if input_dtype != src_dtype :
487+ raise RuntimeError (f"The type of input and src should be the same" )
488+ src_tensor = src
489+ if not (isinstance (src , TRTTensor )):
490+ src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
491+
472492 scatter_layer = ctx .net .add_scatter (
473- input , index , src , trt .tensorrt . ScatterModekELEMENT
493+ input , index , src_tensor , trt .ScatterMode . ELEMENT
474494 )
475- scatter_layer .set_axis ( dim )
495+ scatter_layer .axis = dim
476496 set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
477497 out = scatter_layer .get_output (0 )
478498 return out
0 commit comments