@@ -374,100 +374,38 @@ def index(
374374 return reshape_output
375375
376376
377- def scatter_value (
377+ def scatter (
378378 ctx : ConversionContext ,
379379 target : Target ,
380380 source_ir : Optional [SourceIR ],
381381 name : str ,
382382 input : TRTTensor ,
383383 dim : int ,
384384 index : Union [TRTTensor , np .ndarray , torch .Tensor ],
385- value : float ,
385+ src : Union [ TRTTensor , int , float ] ,
386386) -> TRTTensor :
387- if not isinstance (input , TRTTensor ):
388- raise RuntimeError (
389- f"scatter_tensor received input { input } that is not part "
390- "of the TensorRT region!"
391- )
392387 input_shape = input .shape
393388 index_shape = index .shape
394389 index_shape_list = list (index .shape )
395390 if not (isinstance (index , TRTTensor )):
396391 index = get_trt_tensor (ctx , index , f"_index_tensor" )
397- if len (input_shape ) != len (index_shape ):
398- raise RuntimeError (f"The no of dimensions of input and index should be equal" )
399392 dim = get_positive_dim (dim , len (input_shape ))
400393 dynamic_shape = has_dynamic_shape (input .shape )
401394 if dynamic_shape :
402395 # Check whether slice target dim is dynamic shape dim
403396 assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
404397
405- input_dims = len (input_shape )
406- for i in range (0 , input_dims ):
407- if i != dim and (index_shape [i ] >= input .shape [i ]):
408- raise RuntimeError (
409- f"cannot have index size greater than the input size along dimension { dim } "
410- )
411-
412- value_tensor = get_trt_tensor (
413- ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
414- )
415- value_tensor = cast_trt_tensor (
416- ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
417- )
418- scatter_layer = ctx .net .add_scatter (
419- input , index , value_tensor , trt .ScatterMode .ELEMENT
420- )
421- scatter_layer .axis = dim
422- set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
423- out = scatter_layer .get_output (0 )
424- return out
425-
426-
427- def scatter_src (
428- ctx : ConversionContext ,
429- target : Target ,
430- source_ir : Optional [SourceIR ],
431- name : str ,
432- input : TRTTensor ,
433- dim : Shape ,
434- index : Shape ,
435- src : TRTTensor ,
436- ) -> TRTTensor :
437- if not isinstance (input , TRTTensor ):
438- raise RuntimeError (
439- f"scatter_tensor received input { input } that is not part "
440- "of the TensorRT region!"
441- )
442- input_shape = input .shape
443- index_shape = index .shape
444- src_shape = src .shape
445- if not (isinstance (index , TRTTensor )):
446- index = get_trt_tensor (ctx , index , f"_index_tensor" )
447- if len (input_shape ) != len (index_shape ):
448- raise RuntimeError (f"The no of dimensions of input and index should be equal" )
449- if len (index_shape ) != len (src_shape ):
450- raise RuntimeError (f"The no of dimensions of src and index should be equal" )
451-
452- input_dims = len (input_shape )
453- dim = get_positive_dim (cast (int , dim ), input_dims )
454- dynamic_shape = has_dynamic_shape (input .shape )
455- if dynamic_shape :
456- # Check whether slice target dim is dynamic shape dim
457- assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
458-
459- for i in range (0 , input_dims ):
460- if i != dim and (index_shape [i ] >= input .shape [i ]):
461- raise RuntimeError (
462- f"cannot have index size greater than the input size along dimension { dim } "
463- )
464- input_dtype = input .dtype
465- # required for cases where src is a constant
466- src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
467- if input_dtype != src_dtype :
468- raise RuntimeError (f"The type of input and src should be made" )
469398 src_tensor = src
470- if not (isinstance (src , TRTTensor )):
399+ # scatter.value
400+ if isinstance (src , int ) or isinstance (src , float ):
401+ src_tensor = get_trt_tensor (
402+ ctx , src * torch .ones (index_shape_list ), name + "_value_tensor"
403+ )
404+ src_tensor = cast_trt_tensor (
405+ ctx , src_tensor , input .dtype , name + "_cast_value_tensor"
406+ )
407+ # scatter.src
408+ elif not (isinstance (src , TRTTensor )):
471409 src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
472410
473411 scatter_layer = ctx .net .add_scatter (
0 commit comments