@@ -68,7 +68,7 @@ def select(
6868 indices_tensor = ctx .net .add_constant (
6969 index_value .shape , to_numpy (index_value )
7070 ).get_output (0 )
71- out = gather (input , indices_tensor , dim )
71+ out = gather (ctx , target , source_ir , name , input , indices_tensor , dim )
7272 if len (out .shape ) != 1 :
7373 layer = ctx .net .add_shuffle (out )
7474 return layer .get_output (0 )
@@ -140,7 +140,7 @@ def index(
140140 )
141141 index = adv_indx_indices [0 ]
142142 _LOGGER .debug (f"The advanced index indices is { adv_indx_indices } " )
143- return gather (input , index , indices_tensor )
143+ return gather (ctx , target , source_ir , name , input , index , indices_tensor )
144144 else :
145145 input_shape = input .shape
146146 _LOGGER .debug (f"The input shape is { input .shape } " )
@@ -253,7 +253,7 @@ def index(
253253 dim_tensor_list [adv_indx_indices [i ]],
254254 )
255255
256- gather_out = gather (flatten_tensor , cum_adv_index , 0 )
256+ gather_out = gather (ctx , target , source_ir , name , flatten_tensor , 0 , cum_adv_index )
257257 _LOGGER .debug (f"The shape after cumultative gather is { gather_out .shape } " )
258258 _LOGGER .debug (f"The shape for cumulative adv index is { cum_adv_index } " )
259259
0 commit comments