Skip to content

Commit 262c457

Browse files
authored
Decoder-native transforms benchmark (#982)
1 parent c4bca2d commit 262c457

File tree

1 file changed

+164
-0
lines changed

1 file changed

+164
-0
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import math
2+
from argparse import ArgumentParser
3+
from pathlib import Path
4+
from time import perf_counter_ns
5+
6+
import torch
7+
from torch import Tensor
8+
from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts
9+
from torchcodec.decoders import VideoDecoder
10+
from torchvision.transforms import v2
11+
12+
# Number of timed repetitions used when the caller does not ask for a specific count.
DEFAULT_NUM_EXP = 20


def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1) -> Tensor:
    """Measure how long ``f(*args)`` takes over several repetitions.

    Args:
        f: Callable to time.
        *args: Positional arguments forwarded to ``f`` on every call.
        num_exp: Number of timed calls.
        warmup: Number of untimed calls made before timing starts.

    Returns:
        A 1-D float tensor with one elapsed time per timed call, in
        nanoseconds (the clock is ``perf_counter_ns``).
    """

    def elapsed_ns() -> int:
        # Only the call to ``f`` sits between the two clock reads.
        t0 = perf_counter_ns()
        f(*args)
        return perf_counter_ns() - t0

    # Untimed calls so one-off start-up costs do not land in the samples.
    for _ in range(warmup):
        f(*args)

    return torch.tensor([elapsed_ns() for _ in range(num_exp)]).float()
27+
28+
29+
def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> None:
    """Print summary statistics for a tensor of timings given in nanoseconds.

    The timings are converted to ``unit`` and a single line is printed with
    the median, mean, standard deviation, minimum and maximum.

    Args:
        times: Elapsed times in nanoseconds.
        unit: Output unit; one of ``"ns"``, ``"µs"``, ``"ms"`` or ``"s"``.
        prefix: Label printed at the start of the line, left-aligned and
            padded to 45 characters.

    Raises:
        KeyError: If ``unit`` is not one of the supported units.
    """
    # Multiplier that converts nanoseconds into the requested unit.
    scale = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    scaled = times * scale

    std = scaled.std().item()
    med = scaled.median().item()
    mean = scaled.mean().item()
    # Named so the builtins ``min`` / ``max`` are not shadowed; the printed
    # labels are spelled out below to keep the output text unchanged.
    fastest = scaled.min().item()
    slowest = scaled.max().item()

    print(
        f"{prefix:<45} {med = :.2f}, {mean = :.2f} +- {std:.2f}, "
        f"min = {fastest:.2f}, max = {slowest:.2f} - in {unit}"
    )
45+
46+
47+
def torchvision_resize(
    path: Path, pts_seconds: list[float], dims: tuple[int, int]
) -> Tensor:
    """Decode frames at full resolution, then resize them with torchvision.

    A new decoder is created on every call, so the cost of opening the file
    is part of whatever is timed around this function.

    Args:
        path: Video file to decode.
        pts_seconds: Timestamps, in seconds, of the frames to fetch.
        dims: Target size passed to torchvision as ``size``; callers in this
            file build it as ``(height, width)``.

    Returns:
        The resized frames.
    """
    decoder = create_from_file(str(path), seek_mode="approximate")
    add_video_stream(decoder)
    # Only the first element (the frame data) of the returned tuple is needed.
    raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
    return v2.functional.resize(raw_frames, size=dims)
54+
55+
56+
def torchvision_crop(
    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
) -> Tensor:
    """Decode frames at full resolution, then crop them with torchvision.

    A new decoder is created on every call, so the cost of opening the file
    is part of whatever is timed around this function.

    Args:
        path: Video file to decode.
        pts_seconds: Timestamps, in seconds, of the frames to fetch.
        dims: Crop size as ``(height, width)``.
        x: Horizontal offset of the crop; passed to torchvision as ``left``.
        y: Vertical offset of the crop; passed to torchvision as ``top``.

    Returns:
        The cropped frames.
    """
    decoder = create_from_file(str(path), seek_mode="approximate")
    add_video_stream(decoder)
    # Only the first element (the frame data) of the returned tuple is needed.
    raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
    return v2.functional.crop(raw_frames, top=y, left=x, height=dims[0], width=dims[1])
63+
64+
65+
def decoder_native_resize(
    path: Path, pts_seconds: list[float], dims: tuple[int, int]
) -> Tensor:
    """Decode frames with the resize requested from the decoder itself.

    The resize is passed to ``add_video_stream`` as a transform spec instead
    of being applied to the decoded frames afterwards. A new decoder is
    created on every call.

    Args:
        path: Video file to decode.
        pts_seconds: Timestamps, in seconds, of the frames to fetch.
        dims: Target size; ``dims[0]`` and ``dims[1]`` are written into the
            spec in that order (callers in this file pass ``(height, width)``).

    Returns:
        The first element of the ``get_frames_by_pts`` result, i.e. the frames.
    """
    decoder = create_from_file(str(path), seek_mode="approximate")
    add_video_stream(decoder, transform_specs=f"resize, {dims[0]}, {dims[1]}")
    return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
71+
72+
73+
def decoder_native_crop(
    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
) -> Tensor:
    """Decode frames with the crop requested from the decoder itself.

    The crop is passed to ``add_video_stream`` as a transform spec instead of
    being applied to the decoded frames afterwards. A new decoder is created
    on every call.

    Args:
        path: Video file to decode.
        pts_seconds: Timestamps, in seconds, of the frames to fetch.
        dims: Crop size; ``dims[0]`` and ``dims[1]`` are written into the spec
            in that order (callers in this file pass ``(height, width)``).
        x: Crop offset written into the spec after the size.
        y: Crop offset written into the spec after ``x``.

    Returns:
        The first element of the ``get_frames_by_pts`` result, i.e. the frames.
    """
    decoder = create_from_file(str(path), seek_mode="approximate")
    add_video_stream(decoder, transform_specs=f"crop, {dims[0]}, {dims[1]}, {x}, {y}")
    return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
79+
80+
81+
def main():
    """Benchmark torchvision transforms against decoder-native transforms.

    Reads ``--path`` and ``--num-exp`` from the command line. For each
    sampling fraction of the video's frames and each output-size fraction of
    the input dimensions, it times resize and centre-crop done two ways
    (decode then torchvision transform vs. transform inside the decoder) and
    prints one statistics line per measurement.
    """
    parser = ArgumentParser()
    parser.add_argument("--path", type=str, help="path to file", required=True)
    parser.add_argument(
        "--num-exp",
        type=int,
        default=DEFAULT_NUM_EXP,
        help="number of runs to average over",
    )

    args = parser.parse_args()
    path = Path(args.path)

    metadata = VideoDecoder(path).metadata
    duration = metadata.duration_seconds

    print(
        f"Benchmarking {path.name}, duration: {duration}, codec: {metadata.codec}, averaging over {args.num_exp} runs:"
    )

    input_height = metadata.height
    input_width = metadata.width
    fraction_of_total_frames_to_sample = [0.005, 0.01, 0.05, 0.1]
    fraction_of_input_dimensions = [0.5, 0.25, 0.125]

    for num_fraction in fraction_of_total_frames_to_sample:
        num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
        print(
            f"Sampling {num_fraction * 100}%, {num_frames_to_sample}, of {metadata.num_frames} frames"
        )
        # Evenly spaced timestamps starting at 0 and stopping short of the
        # duration. NOTE(review): assumes the stream starts at 0 seconds.
        uniform_timestamps = [
            i * duration / num_frames_to_sample for i in range(num_frames_to_sample)
        ]

        for dims_fraction in fraction_of_input_dimensions:
            # dims is (height, width), matching how the crop helpers index it.
            dims = (int(input_height * dims_fraction), int(input_width * dims_fraction))

            times = bench(
                torchvision_resize, path, uniform_timestamps, dims, num_exp=args.num_exp
            )
            report_stats(times, prefix=f"torchvision_resize({dims})")

            times = bench(
                decoder_native_resize,
                path,
                uniform_timestamps,
                dims,
                num_exp=args.num_exp,
            )
            report_stats(times, prefix=f"decoder_native_resize({dims})")
            print()

            # Top-left corner of a centred crop. x is the horizontal offset
            # (the helpers pass it as ``left``), so it must come from the
            # widths; y is the vertical offset and comes from the heights.
            center_x = (input_width - dims[1]) // 2
            center_y = (input_height - dims[0]) // 2
            times = bench(
                torchvision_crop,
                path,
                uniform_timestamps,
                dims,
                center_x,
                center_y,
                num_exp=args.num_exp,
            )
            report_stats(
                times, prefix=f"torchvision_crop({dims}, {center_x}, {center_y})"
            )

            times = bench(
                decoder_native_crop,
                path,
                uniform_timestamps,
                dims,
                center_x,
                center_y,
                num_exp=args.num_exp,
            )
            report_stats(
                times, prefix=f"decoder_native_crop({dims}, {center_x}, {center_y})"
            )
            print()
161+
162+
163+
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)