Skip to content

Commit d59a9a0

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] prevent unwrapping in SanitizeBoundingBoxes (#7446)
Reviewed By: vmoens Differential Revision: D44416561 fbshipit-source-id: 8a782227bee116efaed2197365e8ccc30787ec70
1 parent c4364b1 commit d59a9a0

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

test/test_transforms_v2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,6 +2020,9 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
20202020
assert out_image is input_img
20212021
assert out_whatever is whatever
20222022

2023+
assert isinstance(out_boxes, datapoints.BoundingBox)
2024+
assert isinstance(out_masks, datapoints.Mask)
2025+
20232026
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
20242027
assert out_labels is labels
20252028
else:

torchvision/transforms/v2/_misc.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,10 +397,15 @@ def forward(self, *inputs: Any) -> Any:
397397
return tree_unflatten(flat_outputs, spec)
398398

399399
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
400+
is_label = inpt is not None and inpt is params["labels"]
401+
is_bounding_box_or_mask = isinstance(inpt, (datapoints.BoundingBox, datapoints.Mask))
400402

401-
if (inpt is not None and inpt is params["labels"]) or isinstance(
402-
inpt, (datapoints.BoundingBox, datapoints.Mask)
403-
):
404-
inpt = inpt[params["valid"]]
403+
if not (is_label or is_bounding_box_or_mask):
404+
return inpt
405405

406-
return inpt
406+
output = inpt[params["valid"]]
407+
408+
if is_label:
409+
return output
410+
411+
return type(inpt).wrap_like(inpt, output)

0 commit comments

Comments
 (0)