@@ -43,9 +43,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4343 m.def (
4444 " _create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor" );
4545 m.def (
46- " _add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\" cpu\" , str device_variant=\" default\" , (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()" );
46+ " _add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\" cpu\" , str device_variant=\" default\" , str transform_specs= \" \" , (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()" );
4747 m.def (
48- " add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\" cpu\" , str device_variant=\" default\" , (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()" );
48+ " add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\" cpu\" , str device_variant=\" default\" , str transform_specs= \" \" , (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()" );
4949 m.def (
5050 " add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()" );
5151 m.def (" seek_to_pts(Tensor(a!) decoder, float seconds) -> ()" );
@@ -183,6 +183,69 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
183183 }
184184}
185185
186+ int checkedToPositiveInt (const std::string& str) {
187+ int ret = 0 ;
188+ try {
189+ ret = std::stoi (str);
190+ } catch (const std::invalid_argument&) {
191+ TORCH_CHECK (false , " String cannot be converted to an int:" + str);
192+ } catch (const std::out_of_range&) {
193+ TORCH_CHECK (false , " String would become integer out of range:" + str);
194+ }
195+ TORCH_CHECK (ret > 0 , " String must be a positive integer:" + str);
196+ return ret;
197+ }
198+
199+ // Resize transform specs take the form:
200+ //
201+ // "resize, <height>, <width>"
202+ //
203+ // Where "resize" is the string literal and <height> and <width> are positive
204+ // integers.
205+ Transform* makeResizeTransform (
206+ const std::vector<std::string>& resizeTransformSpec) {
207+ TORCH_CHECK (
208+ resizeTransformSpec.size () == 3 ,
209+ " resizeTransformSpec must have 3 elements including its name" );
210+ int height = checkedToPositiveInt (resizeTransformSpec[1 ]);
211+ int width = checkedToPositiveInt (resizeTransformSpec[2 ]);
212+ return new ResizeTransform (FrameDims (height, width));
213+ }
214+
// Splits `str` on every occurrence of `delimiter` and returns the pieces in
// order. Empty pieces between adjacent delimiters (and before a leading
// delimiter) are kept; an empty input yields no pieces, and a trailing
// delimiter does not produce a trailing empty piece.
std::vector<std::string> split(const std::string& str, char delimiter) {
  std::vector<std::string> pieces;
  std::string::size_type pieceStart = 0;
  while (pieceStart < str.size()) {
    const std::string::size_type pieceEnd = str.find(delimiter, pieceStart);
    if (pieceEnd == std::string::npos) {
      // No more delimiters: the remainder is the final piece.
      pieces.push_back(str.substr(pieceStart));
      break;
    }
    pieces.push_back(str.substr(pieceStart, pieceEnd - pieceStart));
    pieceStart = pieceEnd + 1;
  }
  return pieces;
}
224+
225+ // The transformSpecsRaw string is always in the format:
226+ //
227+ // "name1, param1, param2, ...; name2, param1, param2, ...; ..."
228+ //
229+ // Where "nameX" is the name of the transform, and "paramX" are the parameters.
230+ std::vector<Transform*> makeTransforms (const std::string& transformSpecsRaw) {
231+ std::vector<Transform*> transforms;
232+ std::vector<std::string> transformSpecs = split (transformSpecsRaw, ' ;' );
233+ for (const std::string& transformSpecRaw : transformSpecs) {
234+ std::vector<std::string> transformSpec = split (transformSpecRaw, ' ,' );
235+ TORCH_CHECK (
236+ transformSpec.size () >= 1 ,
237+ " Invalid transform spec: " + transformSpecRaw);
238+
239+ auto name = transformSpec[0 ];
240+ if (name == " resize" ) {
241+ transforms.push_back (makeResizeTransform (transformSpec));
242+ } else {
243+ TORCH_CHECK (false , " Invalid transform name: " + name);
244+ }
245+ }
246+ return transforms;
247+ }
248+
186249} // namespace
187250
188251// ==============================
@@ -252,36 +315,18 @@ at::Tensor _create_from_file_like(
252315
253316void _add_video_stream (
254317 at::Tensor& decoder,
255- std::optional<int64_t > width = std::nullopt ,
256- std::optional<int64_t > height = std::nullopt ,
257318 std::optional<int64_t > num_threads = std::nullopt ,
258319 std::optional<std::string_view> dimension_order = std::nullopt ,
259320 std::optional<int64_t > stream_index = std::nullopt ,
260321 std::string_view device = " cpu" ,
261322 std::string_view device_variant = " default" ,
323+ std::string_view transform_specs = " " ,
262324 std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
263325 custom_frame_mappings = std::nullopt ,
264326 std::optional<std::string_view> color_conversion_library = std::nullopt ) {
265327 VideoStreamOptions videoStreamOptions;
266328 videoStreamOptions.ffmpegThreadCount = num_threads;
267329
268- // TODO: Eliminate this temporary bridge code. This exists because we have
269- // not yet exposed the transforms API on the Python side. We also want
270- // to remove the `width` and `height` arguments from the Python API.
271- //
272- // TEMPORARY BRIDGE CODE START
273- TORCH_CHECK (
274- width.has_value () == height.has_value (),
275- " width and height must both be set or unset." );
276- std::vector<Transform*> transforms;
277- if (width.has_value ()) {
278- transforms.push_back (
279- new ResizeTransform (FrameDims (height.value (), width.value ())));
280- width.reset ();
281- height.reset ();
282- }
283- // TEMPORARY BRIDGE CODE END
284-
285330 if (dimension_order.has_value ()) {
286331 std::string stdDimensionOrder{dimension_order.value ()};
287332 TORCH_CHECK (stdDimensionOrder == " NHWC" || stdDimensionOrder == " NCHW" );
@@ -309,6 +354,9 @@ void _add_video_stream(
309354 videoStreamOptions.device = torch::Device (std::string (device));
310355 videoStreamOptions.deviceVariant = device_variant;
311356
357+ std::vector<Transform*> transforms =
358+ makeTransforms (std::string (transform_specs));
359+
312360 std::optional<SingleStreamDecoder::FrameMappings> converted_mappings =
313361 custom_frame_mappings.has_value ()
314362 ? std::make_optional (makeFrameMappings (custom_frame_mappings.value ()))
@@ -324,24 +372,22 @@ void _add_video_stream(
324372// Add a new video stream at `stream_index` using the provided options.
325373void add_video_stream (
326374 at::Tensor& decoder,
327- std::optional<int64_t > width = std::nullopt ,
328- std::optional<int64_t > height = std::nullopt ,
329375 std::optional<int64_t > num_threads = std::nullopt ,
330376 std::optional<std::string_view> dimension_order = std::nullopt ,
331377 std::optional<int64_t > stream_index = std::nullopt ,
332378 std::string_view device = " cpu" ,
333379 std::string_view device_variant = " default" ,
380+ std::string_view transform_specs = " " ,
334381 const std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>&
335382 custom_frame_mappings = std::nullopt ) {
336383 _add_video_stream (
337384 decoder,
338- width,
339- height,
340385 num_threads,
341386 dimension_order,
342387 stream_index,
343388 device,
344389 device_variant,
390+ transform_specs,
345391 custom_frame_mappings);
346392}
347393
0 commit comments