@@ -68,13 +68,26 @@ def select(
6868 indices_tensor = ctx .net .add_constant (
6969 index_value .shape , to_numpy (index_value )
7070 ).get_output (0 )
71- layer = ctx .net .add_gather (input , indices_tensor , dim )
72- out = layer .get_output (0 )
71+ out = gather (input , indices_tensor , dim )
7372 if len (out .shape ) != 1 :
7473 layer = ctx .net .add_shuffle (out )
7574 return layer .get_output (0 )
7675
7776
77+ def gather (
78+ ctx : ConversionContext ,
79+ target : Target ,
80+ source_ir : Optional [SourceIR ],
81+ name : str ,
82+ input : TRTTensor ,
83+ dim : int ,
84+ index : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
85+ ) -> TRTTensor :
86+ gather_layer = ctx .net .add_gather (input , index , dim )
87+ set_layer_name (gather_layer , target , name + "_gather" , source_ir )
88+ return gather_layer .get_output (0 )
89+
90+
7891def index (
7992 ctx : ConversionContext ,
8093 target : Target ,
@@ -127,9 +140,7 @@ def index(
127140 )
128141 index = adv_indx_indices [0 ]
129142 _LOGGER .debug (f"The advanced index indices is { adv_indx_indices } " )
130- gather_layer = ctx .net .add_gather (input , indices_tensor , index )
131- set_layer_name (gather_layer , target , name + "_index_gather" , source_ir )
132- return gather_layer .get_output (0 )
143+ return gather (input , index , indices_tensor )
133144 else :
134145 input_shape = input .shape
135146 _LOGGER .debug (f"The input shape is { input .shape } " )
@@ -242,11 +253,7 @@ def index(
242253 dim_tensor_list [adv_indx_indices [i ]],
243254 )
244255
245- gather_layer_element = ctx .net .add_gather (flatten_tensor , cum_adv_index , 0 )
246- set_layer_name (
247- gather_layer_element , target , name + "_index_gather_element" , source_ir
248- )
249- gather_out = gather_layer_element .get_output (0 )
256+ gather_out = gather (flatten_tensor , cum_adv_index , 0 )
250257 _LOGGER .debug (f"The shape after cumultative gather is { gather_out .shape } " )
251258 _LOGGER .debug (f"The shape for cumulative adv index is { cum_adv_index } " )
252259
0 commit comments