-
Notifications
You must be signed in to change notification settings - Fork 368
feat: support aten.roll dynamo converter #2569
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
5bc8314 to
5be128d
Compare
5be128d to
d58f865
Compare
| from .harness import DispatchTestCase | ||
|
|
||
|
|
||
| class TestRollConverter(DispatchTestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding a case where both shifts and dims are single integers, which is a supported case in the docstring. These may be casted to lists in the operator before the converter ever gets them, but it is still a valid input I believe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your review! I found the schema is:
- func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
Does this mean shifts and dims should be a 1d list?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
roll(tensor, 2, 3) --> roll(tensor, [2], [3])
pool(3) --> pool([3, 3])
To share additional documentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gs-olive Thanks for the details!
Unfortunately, when testing shifts=2, dims=0, I got error:
File "<eval_with_key>.0 from /home/zewenl/TensorRT/tests/py/dynamo/conversion/test_roll_aten.py:35 in forward", line 5, in forward
roll_default = torch.ops.aten.roll.default(x, 2, 0); x = None
File "/home/zewenl/.local/lib/python3.10/site-packages/torch/_ops.py", line 571, in __call__
return self_._op(*args, **(kwargs or {}))
RuntimeError: aten::roll() Expected a value of type 'List[int]' for argument 'shifts' but instead found type 'int'.
Position: 1
Value: 2
Declaration: aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
Python error details: TypeError: 'int' object is not iterable
Then, I also tested shifts=(2,), dims=0, it works.
It seems that pytorch requires shifts to be a list.
According to the schema - func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor, I guess SymInt[1] and int[1] may have different behaviors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be so, yes, though that is strange - thanks for the update.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'm working on adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor converter, output_size expects List[int] as well:
RuntimeError: aten::adaptive_avg_pool2d() Expected a value of type 'List[int]' for argument 'output_size' but instead found type 'int'.
Position: 1
Value: 3
Declaration: aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
d58f865 to
6e19c81
Compare
6e19c81 to
a10ac64
Compare
Description
Support
aten.rolldynamo converter.Fixes #2567
Type of change
Checklist: