@@ -393,100 +393,38 @@ def index_select(
393393 return gather_layer .get_output (0 )
394394
395395
396- def scatter_value (
396+ def scatter (
397397 ctx : ConversionContext ,
398398 target : Target ,
399399 source_ir : Optional [SourceIR ],
400400 name : str ,
401401 input : TRTTensor ,
402402 dim : int ,
403403 index : Union [TRTTensor , np .ndarray , torch .Tensor ],
404- value : float ,
404+ src : Union [ TRTTensor , int , float ] ,
405405) -> TRTTensor :
406- if not isinstance (input , TRTTensor ):
407- raise RuntimeError (
408- f"scatter_tensor received input { input } that is not part "
409- "of the TensorRT region!"
410- )
411406 input_shape = input .shape
412407 index_shape = index .shape
413408 index_shape_list = list (index .shape )
414409 if not (isinstance (index , TRTTensor )):
415410 index = get_trt_tensor (ctx , index , f"_index_tensor" )
416- if len (input_shape ) != len (index_shape ):
417- raise RuntimeError (f"The no of dimensions of input and index should be equal" )
418411 dim = get_positive_dim (dim , len (input_shape ))
419412 dynamic_shape = has_dynamic_shape (input .shape )
420413 if dynamic_shape :
421414 # Check whether slice target dim is dynamic shape dim
422415 assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
423416
424- input_dims = len (input_shape )
425- for i in range (0 , input_dims ):
426- if i != dim and (index_shape [i ] >= input .shape [i ]):
427- raise RuntimeError (
428- f"cannot have index size greater than the input size along dimension { dim } "
429- )
430-
431- value_tensor = get_trt_tensor (
432- ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
433- )
434- value_tensor = cast_trt_tensor (
435- ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
436- )
437- scatter_layer = ctx .net .add_scatter (
438- input , index , value_tensor , trt .ScatterMode .ELEMENT
439- )
440- scatter_layer .axis = dim
441- set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
442- out = scatter_layer .get_output (0 )
443- return out
444-
445-
446- def scatter_src (
447- ctx : ConversionContext ,
448- target : Target ,
449- source_ir : Optional [SourceIR ],
450- name : str ,
451- input : TRTTensor ,
452- dim : Shape ,
453- index : Shape ,
454- src : TRTTensor ,
455- ) -> TRTTensor :
456- if not isinstance (input , TRTTensor ):
457- raise RuntimeError (
458- f"scatter_tensor received input { input } that is not part "
459- "of the TensorRT region!"
460- )
461- input_shape = input .shape
462- index_shape = index .shape
463- src_shape = src .shape
464- if not (isinstance (index , TRTTensor )):
465- index = get_trt_tensor (ctx , index , f"_index_tensor" )
466- if len (input_shape ) != len (index_shape ):
467- raise RuntimeError (f"The no of dimensions of input and index should be equal" )
468- if len (index_shape ) != len (src_shape ):
469- raise RuntimeError (f"The no of dimensions of src and index should be equal" )
470-
471- input_dims = len (input_shape )
472- dim = get_positive_dim (cast (int , dim ), input_dims )
473- dynamic_shape = has_dynamic_shape (input .shape )
474- if dynamic_shape :
475- # Check whether slice target dim is dynamic shape dim
476- assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
477-
478- for i in range (0 , input_dims ):
479- if i != dim and (index_shape [i ] >= input .shape [i ]):
480- raise RuntimeError (
481- f"cannot have index size greater than the input size along dimension { dim } "
482- )
483- input_dtype = input .dtype
484- # required for cases where src is a constant
485- src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
486- if input_dtype != src_dtype :
487- raise RuntimeError (f"The type of input and src should be made" )
488417 src_tensor = src
489- if not (isinstance (src , TRTTensor )):
418+ # scatter.value
419+ if isinstance (src , int ) or isinstance (src , float ):
420+ src_tensor = get_trt_tensor (
421+ ctx , src * torch .ones (index_shape_list ), name + "_value_tensor"
422+ )
423+ src_tensor = cast_trt_tensor (
424+ ctx , src_tensor , input .dtype , name + "_cast_value_tensor"
425+ )
426+ # scatter.src
427+ elif not (isinstance (src , TRTTensor )):
490428 src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
491429
492430 scatter_layer = ctx .net .add_scatter (
0 commit comments