Skip to content

Commit a745362

Browse files
committed
addressing review comments and changing test names
1 parent 975ce31 commit a745362

File tree

3 files changed: +26 −94 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -697,28 +697,21 @@ def aten_ops_clamp(
697697
)
698698

699699

700-
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
701-
def aten_ops_scatter_value(
702-
ctx: ConversionContext,
703-
target: Target,
704-
args: Tuple[Argument, ...],
705-
kwargs: Dict[str, Argument],
706-
name: str,
707-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
708-
return impl.select.scatter_value(
709-
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
710-
)
711-
712-
700+
@enforce_tensor_types(
701+
{
702+
0: (TRTTensor,),
703+
}
704+
)
713705
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
714-
def aten_ops_scatter_src(
706+
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
707+
def aten_ops_scatter(
715708
ctx: ConversionContext,
716709
target: Target,
717710
args: Tuple[Argument, ...],
718711
kwargs: Dict[str, Argument],
719712
name: str,
720713
) -> Union[TRTTensor, Sequence[TRTTensor]]:
721-
return impl.select.scatter_src(
714+
return impl.select.scatter(
722715
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
723716
)
724717

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 12 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -374,100 +374,38 @@ def index(
374374
return reshape_output
375375

376376

377-
def scatter_value(
377+
def scatter(
378378
ctx: ConversionContext,
379379
target: Target,
380380
source_ir: Optional[SourceIR],
381381
name: str,
382382
input: TRTTensor,
383383
dim: int,
384384
index: Union[TRTTensor, np.ndarray, torch.Tensor],
385-
value: float,
385+
src: Union[TRTTensor, int, float],
386386
) -> TRTTensor:
387-
if not isinstance(input, TRTTensor):
388-
raise RuntimeError(
389-
f"scatter_tensor received input {input} that is not part "
390-
"of the TensorRT region!"
391-
)
392387
input_shape = input.shape
393388
index_shape = index.shape
394389
index_shape_list = list(index.shape)
395390
if not (isinstance(index, TRTTensor)):
396391
index = get_trt_tensor(ctx, index, f"_index_tensor")
397-
if len(input_shape) != len(index_shape):
398-
raise RuntimeError(f"The no of dimensions of input and index should be equal")
399392
dim = get_positive_dim(dim, len(input_shape))
400393
dynamic_shape = has_dynamic_shape(input.shape)
401394
if dynamic_shape:
402395
# Check whether slice target dim is dynamic shape dim
403396
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
404397

405-
input_dims = len(input_shape)
406-
for i in range(0, input_dims):
407-
if i != dim and (index_shape[i] >= input.shape[i]):
408-
raise RuntimeError(
409-
f"cannot have index size greater than the input size along dimension {dim}"
410-
)
411-
412-
value_tensor = get_trt_tensor(
413-
ctx, value * torch.ones(index_shape_list), name + "_value_tensor"
414-
)
415-
value_tensor = cast_trt_tensor(
416-
ctx, value_tensor, input.dtype, name + "_cast_value_tensor"
417-
)
418-
scatter_layer = ctx.net.add_scatter(
419-
input, index, value_tensor, trt.ScatterMode.ELEMENT
420-
)
421-
scatter_layer.axis = dim
422-
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
423-
out = scatter_layer.get_output(0)
424-
return out
425-
426-
427-
def scatter_src(
428-
ctx: ConversionContext,
429-
target: Target,
430-
source_ir: Optional[SourceIR],
431-
name: str,
432-
input: TRTTensor,
433-
dim: Shape,
434-
index: Shape,
435-
src: TRTTensor,
436-
) -> TRTTensor:
437-
if not isinstance(input, TRTTensor):
438-
raise RuntimeError(
439-
f"scatter_tensor received input {input} that is not part "
440-
"of the TensorRT region!"
441-
)
442-
input_shape = input.shape
443-
index_shape = index.shape
444-
src_shape = src.shape
445-
if not (isinstance(index, TRTTensor)):
446-
index = get_trt_tensor(ctx, index, f"_index_tensor")
447-
if len(input_shape) != len(index_shape):
448-
raise RuntimeError(f"The no of dimensions of input and index should be equal")
449-
if len(index_shape) != len(src_shape):
450-
raise RuntimeError(f"The no of dimensions of src and index should be equal")
451-
452-
input_dims = len(input_shape)
453-
dim = get_positive_dim(cast(int, dim), input_dims)
454-
dynamic_shape = has_dynamic_shape(input.shape)
455-
if dynamic_shape:
456-
# Check whether slice target dim is dynamic shape dim
457-
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
458-
459-
for i in range(0, input_dims):
460-
if i != dim and (index_shape[i] >= input.shape[i]):
461-
raise RuntimeError(
462-
f"cannot have index size greater than the input size along dimension {dim}"
463-
)
464-
input_dtype = input.dtype
465-
# required for cases where src is a constant
466-
src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
467-
if input_dtype != src_dtype:
468-
raise RuntimeError(f"The type of input and src should be made")
469398
src_tensor = src
470-
if not (isinstance(src, TRTTensor)):
399+
# scatter.value
400+
if isinstance(src, int) or isinstance(src, float):
401+
src_tensor = get_trt_tensor(
402+
ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
403+
)
404+
src_tensor = cast_trt_tensor(
405+
ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
406+
)
407+
# scatter.src
408+
elif not (isinstance(src, TRTTensor)):
471409
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")
472410

473411
scatter_layer = ctx.net.add_scatter(

tests/py/dynamo/conversion/test_scatter_aten.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
2-
from harness import DispatchTestCase
32
from parameterized import parameterized
43
from torch.testing._internal.common_utils import run_tests
54
from torch_tensorrt import Input
65

6+
from .harness import DispatchTestCase
7+
78

89
class TestScatterValueConverter(DispatchTestCase):
910
@parameterized.expand(
@@ -87,25 +88,25 @@ class TestScatterSrcConverter(DispatchTestCase):
8788
@parameterized.expand(
8889
[
8990
(
90-
"scatter_zero_dim_indexOne_constant_src",
91+
"scatter_zero_dim_indexOne_src",
9192
0,
9293
torch.tensor([[0, 1, 2, 0]]),
9394
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
9495
),
9596
(
96-
"scatter_zero_dim_indexTwo_constant_src",
97+
"scatter_zero_dim_indexTwo_src",
9798
0,
9899
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
99100
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
100101
),
101102
(
102-
"scatter_one_dim_indexOne_constant_src",
103+
"scatter_one_dim_indexOne_src",
103104
1,
104105
torch.tensor([[0, 1, 2, 0]]),
105106
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
106107
),
107108
(
108-
"scatter_one_dim_indexTwo_constant_src",
109+
"scatter_one_dim_indexTwo_src",
109110
1,
110111
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
111112
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),

0 commit comments

Comments (0)