@@ -7348,6 +7348,82 @@ def test_no_label(self):
73487348 assert isinstance (out_img , tv_tensors .Image )
73497349 assert isinstance (out_boxes , tv_tensors .BoundingBoxes )
73507350
7351+ def test_semantic_masks_passthrough (self ):
7352+ # Test that semantic masks (2D) pass through unchanged
7353+ H , W = 256 , 128
7354+ boxes = tv_tensors .BoundingBoxes (
7355+ [[0 , 0 , 50 , 50 ], [60 , 60 , 100 , 100 ]],
7356+ format = tv_tensors .BoundingBoxFormat .XYXY ,
7357+ canvas_size = (H , W ),
7358+ )
7359+
7360+ # Create semantic segmentation mask (H, W) - should NOT be sanitized
7361+ semantic_mask = tv_tensors .Mask (torch .randint (0 , 10 , size = (H , W )))
7362+
7363+ sample = {
7364+ "boxes" : boxes ,
7365+ "semantic_mask" : semantic_mask ,
7366+ }
7367+
7368+ out = transforms .SanitizeBoundingBoxes (labels_getter = None )(sample )
7369+
7370+ # Check that semantic mask passed through unchanged
7371+ assert isinstance (out ["semantic_mask" ], tv_tensors .Mask )
7372+ assert out ["semantic_mask" ].shape == (H , W )
7373+ assert_equal (out ["semantic_mask" ], semantic_mask )
7374+
7375+ def test_masks_with_mismatched_shape_passthrough (self ):
7376+ # Test that masks with shapes that don't match the number of boxes are passed through
7377+ H , W = 256 , 128
7378+ boxes = tv_tensors .BoundingBoxes (
7379+ [[0 , 0 , 10 , 10 ], [20 , 20 , 30 , 30 ], [50 , 50 , 60 , 60 ]],
7380+ format = tv_tensors .BoundingBoxFormat .XYXY ,
7381+ canvas_size = (H , W ),
7382+ )
7383+
7384+ # Create masks with different number of instances than boxes
7385+ mismatched_masks = tv_tensors .Mask (torch .randint (0 , 2 , size = (5 , H , W ))) # 5 masks but 3 boxes
7386+
7387+ sample = {
7388+ "boxes" : boxes ,
7389+ "masks" : mismatched_masks ,
7390+ }
7391+
7392+ # Should not raise an error, masks should pass through unchanged
7393+ out = transforms .SanitizeBoundingBoxes (labels_getter = None )(sample )
7394+
7395+ assert isinstance (out ["masks" ], tv_tensors .Mask )
7396+ assert out ["masks" ].shape == (5 , H , W )
7397+ assert_equal (out ["masks" ], mismatched_masks )
7398+
7399+ def test_per_instance_masks_sanitized (self ):
7400+ # Test that per-instance masks (N, H, W) are correctly sanitized
7401+ H , W = 256 , 128
7402+ boxes , expected_valid_mask = self ._get_boxes_and_valid_mask (H = H , W = W , min_size = 10 , min_area = 10 )
7403+ valid_indices = [i for (i , is_valid ) in enumerate (expected_valid_mask ) if is_valid ]
7404+ num_boxes = boxes .shape [0 ]
7405+
7406+ # Create per-instance masks matching the number of boxes
7407+ per_instance_masks = tv_tensors .Mask (torch .randint (0 , 2 , size = (num_boxes , H , W )))
7408+ labels = torch .arange (num_boxes )
7409+
7410+ sample = {
7411+ "boxes" : boxes ,
7412+ "masks" : per_instance_masks ,
7413+ "labels" : labels ,
7414+ }
7415+
7416+ out = transforms .SanitizeBoundingBoxes (min_size = 10 , min_area = 10 )(sample )
7417+
7418+ # Check that masks were sanitized correctly
7419+ assert isinstance (out ["masks" ], tv_tensors .Mask )
7420+ assert out ["masks" ].shape [0 ] == len (valid_indices )
7421+ assert out ["masks" ].shape [0 ] == out ["boxes" ].shape [0 ] == out ["labels" ].shape [0 ]
7422+
7423+ # Verify correct masks were kept
7424+ for i , valid_idx in enumerate (valid_indices ):
7425+ assert_equal (out ["masks" ][i ], per_instance_masks [valid_idx ])
7426+
73517427 def test_errors_transform (self ):
73527428 good_bbox = tv_tensors .BoundingBoxes (
73537429 [[0 , 0 , 10 , 10 ]],
0 commit comments