|
11 | 11 | import torch |
12 | 12 | from torchcodec.decoders import AudioDecoder |
13 | 13 |
|
14 | | -from torchcodec.encoders import AudioEncoder |
| 14 | +from torchcodec.encoders import AudioEncoder, VideoEncoder |
15 | 15 |
|
16 | 16 | from .utils import ( |
17 | 17 | assert_tensor_close_on_at_least, |
@@ -564,3 +564,115 @@ def write(self, data): |
564 | 564 | RuntimeError, match="File like object must implement a seek method" |
565 | 565 | ): |
566 | 566 | encoder.to_file_like(NoSeekMethod(), format="wav") |
| 567 | + |
| 568 | + |
| 569 | +class TestVideoEncoder: |
| 570 | + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) |
| 571 | + def test_bad_input_parameterized(self, tmp_path, method): |
| 572 | + if method == "to_file": |
| 573 | + valid_params = dict(dest=str(tmp_path / "output.mp4")) |
| 574 | + elif method == "to_tensor": |
| 575 | + valid_params = dict(format="mp4") |
| 576 | + elif method == "to_file_like": |
| 577 | + valid_params = dict(file_like=io.BytesIO(), format="mp4") |
| 578 | + else: |
| 579 | + raise ValueError(f"Unknown method: {method}") |
| 580 | + |
| 581 | + with pytest.raises( |
| 582 | + ValueError, match="Expected uint8 frames, got frames.dtype = torch.float32" |
| 583 | + ): |
| 584 | + encoder = VideoEncoder( |
| 585 | + frames=torch.rand(5, 3, 64, 64), |
| 586 | + frame_rate=30, |
| 587 | + ) |
| 588 | + getattr(encoder, method)(**valid_params) |
| 589 | + |
| 590 | + with pytest.raises( |
| 591 | + ValueError, match=r"Expected 4D frames, got frames.shape = torch.Size" |
| 592 | + ): |
| 593 | + encoder = VideoEncoder( |
| 594 | + frames=torch.zeros(10), |
| 595 | + frame_rate=30, |
| 596 | + ) |
| 597 | + getattr(encoder, method)(**valid_params) |
| 598 | + |
| 599 | + with pytest.raises( |
| 600 | + RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2" |
| 601 | + ): |
| 602 | + encoder = VideoEncoder( |
| 603 | + frames=torch.zeros((5, 2, 64, 64), dtype=torch.uint8), |
| 604 | + frame_rate=30, |
| 605 | + ) |
| 606 | + getattr(encoder, method)(**valid_params) |
| 607 | + |
| 608 | + def test_bad_input(self, tmp_path): |
| 609 | + encoder = VideoEncoder( |
| 610 | + frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8), |
| 611 | + frame_rate=30, |
| 612 | + ) |
| 613 | + |
| 614 | + with pytest.raises( |
| 615 | + RuntimeError, |
| 616 | + match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?", |
| 617 | + ): |
| 618 | + encoder.to_file("./file.bad_extension") |
| 619 | + |
| 620 | + with pytest.raises( |
| 621 | + RuntimeError, |
| 622 | + match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?", |
| 623 | + ): |
| 624 | + encoder.to_file("./bad/path.mp3") |
| 625 | + |
| 626 | + with pytest.raises( |
| 627 | + RuntimeError, |
| 628 | + match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format", |
| 629 | + ): |
| 630 | + encoder.to_tensor(format="bad_format") |
| 631 | + |
| 632 | + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) |
| 633 | + def test_contiguity(self, method, tmp_path): |
| 634 | + # Ensure that 2 sets of video frames with the same pixel values are encoded |
| 635 | + # in the same way, regardless of their memory layout. Here we encode 2 equal |
| 636 | + # frame tensors, one is contiguous while the other is non-contiguous. |
| 637 | + |
| 638 | + num_frames, channels, height, width = 5, 3, 64, 64 |
| 639 | + contiguous_frames = torch.randint( |
| 640 | + 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 |
| 641 | + ).contiguous() |
| 642 | + assert contiguous_frames.is_contiguous() |
| 643 | + |
| 644 | + # Permute NCHW to NHWC, then update the memory layout, then permute back |
| 645 | + non_contiguous_frames = ( |
| 646 | + contiguous_frames.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2) |
| 647 | + ) |
| 648 | + assert non_contiguous_frames.stride() != contiguous_frames.stride() |
| 649 | + assert not non_contiguous_frames.is_contiguous() |
| 650 | + assert non_contiguous_frames.is_contiguous(memory_format=torch.channels_last) |
| 651 | + |
| 652 | + torch.testing.assert_close( |
| 653 | + contiguous_frames, non_contiguous_frames, rtol=0, atol=0 |
| 654 | + ) |
| 655 | + |
| 656 | + def encode_to_tensor(frames): |
| 657 | + if method == "to_file": |
| 658 | + dest = str(tmp_path / "output.mp4") |
| 659 | + VideoEncoder(frames, frame_rate=30).to_file(dest=dest) |
| 660 | + with open(dest, "rb") as f: |
| 661 | + return torch.frombuffer(f.read(), dtype=torch.uint8) |
| 662 | + elif method == "to_tensor": |
| 663 | + return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4") |
| 664 | + elif method == "to_file_like": |
| 665 | + file_like = io.BytesIO() |
| 666 | + VideoEncoder(frames, frame_rate=30).to_file_like( |
| 667 | + file_like, format="mp4" |
| 668 | + ) |
| 669 | + return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8) |
| 670 | + else: |
| 671 | + raise ValueError(f"Unknown method: {method}") |
| 672 | + |
| 673 | + encoded_from_contiguous = encode_to_tensor(contiguous_frames) |
| 674 | + encoded_from_non_contiguous = encode_to_tensor(non_contiguous_frames) |
| 675 | + |
| 676 | + torch.testing.assert_close( |
| 677 | + encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 |
| 678 | + ) |
0 commit comments