team-10/venv/Lib/site-packages/transformers/video_utils.py
2025-08-02 02:00:33 +02:00

778 lines
30 KiB
Python

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from collections.abc import Iterable
from contextlib import redirect_stdout
from dataclasses import dataclass
from io import BytesIO
from typing import Callable, Optional, Union
from urllib.parse import urlparse
import numpy as np
import requests
from .image_transforms import PaddingMode, to_channel_dimension_format
from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image
from .utils import (
is_av_available,
is_cv2_available,
is_decord_available,
is_numpy_array,
is_torch_available,
is_torch_tensor,
is_torchcodec_available,
is_torchvision_available,
is_vision_available,
is_yt_dlp_available,
logging,
requires_backends,
)
if is_vision_available():
import PIL.Image
import PIL.ImageOps
if is_torchvision_available():
from torchvision import io as torchvision_io
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
# Type alias for the video inputs handled by the helpers in this module.
# Members are written as forward-reference strings because PIL and torch are
# optional dependencies that are only imported when available.
VideoInput = Union[
    list["PIL.Image.Image"],
    "np.ndarray",
    "torch.Tensor",
    list["np.ndarray"],
    list["torch.Tensor"],
    list[list["PIL.Image.Image"]],
    list[list["np.ndarray"]],
    list[list["torch.Tensor"]],
]  # noqa
@dataclass
class VideoMetadata:
    """
    Basic properties of a video, filled in by the `read_video_*` decoders of this module.

    Fields can be read as attributes or with dictionary-style access, e.g. `metadata["fps"]`.
    """

    # Number of frames in the video.
    total_num_frames: int
    # Frame rate of the video.
    fps: float
    # Length of the video; the readers in this file either compute it as
    # `total_num_frames / fps` or take the decoder's reported `duration_seconds`.
    duration: float
    # Name of the decoding backend that produced this metadata (e.g. "pyav").
    video_backend: str

    def __getitem__(self, item):
        # Dictionary-style lookup simply forwards to attribute access.
        return getattr(self, item)
def is_valid_video_frame(frame):
    """
    Tell whether `frame` can serve as a single video frame: either a `PIL.Image.Image`, or a
    numpy array / torch tensor with exactly three dimensions.
    """
    if isinstance(frame, PIL.Image.Image):
        return True
    is_array_like = is_numpy_array(frame) or is_torch_tensor(frame)
    return is_array_like and frame.ndim == 3
def is_valid_video(video):
    """
    Tell whether `video` is a single video: a list/tuple made only of valid frames, or a
    numpy array / torch tensor with exactly four dimensions.
    """
    if isinstance(video, (list, tuple)):
        return all(is_valid_video_frame(frame) for frame in video)
    is_array_like = is_numpy_array(video) or is_torch_tensor(video)
    return is_array_like and video.ndim == 4
def valid_videos(videos):
    """
    Check whether `videos` is an acceptable video input.

    Accepted inputs are:
        - a list/tuple in which every item is a valid video or a valid frame (i.e. one video given as a list of
          frames, or a batch of videos);
        - a single 4D array/tensor video;
        - a 5D array/tensor holding a batch of videos.

    Returns:
        `bool`: `True` when the input matches one of the layouts above, `False` otherwise.
    """
    # If we have a list of videos, it could be either one video as list of frames or a batch
    if isinstance(videos, (list, tuple)):
        return all(is_valid_video(item) or is_valid_video_frame(item) for item in videos)

    # If not a list, then we have a single 4D video ...
    if is_valid_video(videos):
        return True

    # ... or a 5D batched array/tensor. The type is checked before touching `.ndim` so that
    # arbitrary objects are reported as invalid instead of raising `AttributeError`.
    return bool((is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5)
def is_batched_video(videos):
    """
    Tell whether `videos` holds a batch of videos: a list/tuple whose first item is itself a valid
    video, or a numpy array / torch tensor with exactly five dimensions.
    """
    if isinstance(videos, (list, tuple)):
        # Only the first element is inspected; the rest are assumed to share its format.
        return is_valid_video(videos[0])
    if not (is_numpy_array(videos) or is_torch_tensor(videos)):
        return False
    return videos.ndim == 5
def is_scaled_video(video: np.ndarray) -> bool:
    """
    Tell whether the pixel values of `video` already lie in the [0, 1] range.
    """
    # Only the value range is inspected, not the dtype: a floating-point video
    # may still hold pixel values in [0, 255].
    lowest = np.min(video)
    highest = np.max(video)
    return lowest >= 0 and highest <= 1
def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union["np.ndarray", "torch.Tensor"]]:
    """
    Turn every video of a batch that is given as a list of frames into one stacked array.

    The format of the whole batch is decided from its first element only: when that element is not a
    list/tuple of frames, the batch is returned untouched.

    Args:
        videos (`VideoInput`):
            Batch of videos to convert.

    Returns:
        The input batch itself, or a new list with one stacked `np.ndarray` per video.
    """
    first_video = videos[0]
    if not isinstance(first_video, (list, tuple)):
        return videos

    # Each frame goes through `np.array` and the frames are stacked along a new leading axis.
    return [np.stack([np.array(frame) for frame in frames]) for frames in videos]
def make_batched_videos(videos) -> list[Union["np.ndarray", "torch.Tensor"]]:
    """
    Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1.
    If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image`
    frames are converted to 4D arrays.

    We assume that all inputs in the list are in the same format, based on the type of the first element.

    Args:
        videos (`VideoInput`):
            Video inputs to turn into a list of videos.

    Returns:
        The batch of videos, after `convert_pil_frames_to_video` has been applied to it.

    Raises:
        `ValueError`: If the input matches none of the supported layouts (batch of videos, single video,
        single frame, nested batch of videos).
    """
    if is_batched_video(videos):
        pass
    elif is_valid_video(videos):
        videos = [videos]
    # only one frame passed, thus we unsqueeze time dim
    elif is_valid_image(videos):
        videos = [np.array(videos)[None, ...]]
    # nested batch so we need to unflatten
    elif (
        isinstance(videos, (list, tuple))
        and isinstance(videos[0], (list, tuple))
        and len(videos[0]) > 0
        and is_valid_video(videos[0][0])
    ):
        videos = [video for sublist in videos for video in sublist]
    else:
        # The validation used to test the truthiness of the `valid_videos` function object
        # (always true), so this error could never be raised. It is now raised for any
        # input that none of the branches above knows how to batch.
        raise ValueError(
            f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got"
            f" type {type(videos)}."
        )

    return convert_pil_frames_to_video(videos)
def get_video_size(video: np.ndarray, channel_dim: ChannelDimension = None) -> tuple[int, int]:
    """
    Read the spatial size of a video from its shape.

    Args:
        video (`np.ndarray`):
            The video whose size is requested.
        channel_dim (`ChannelDimension`, *optional*):
            Position of the channel axis. When `None`, it is inferred from `video` with
            `infer_channel_dimension_format`.

    Returns:
        `tuple[int, int]`: The `(height, width)` of the video.

    Raises:
        `ValueError`: If the channel dimension is neither `ChannelDimension.FIRST` nor `ChannelDimension.LAST`.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(video)

    shape = video.shape
    if channel_dim == ChannelDimension.FIRST:
        # (..., channels, height, width)
        height, width = shape[-2], shape[-1]
    elif channel_dim == ChannelDimension.LAST:
        # (..., height, width, channels)
        height, width = shape[-3], shape[-2]
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")
    return height, width
def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
    """
    Build the frame indices used to sample `num_frames` frames uniformly out of `total_num_frames`
    when loading a video.

    Args:
        total_num_frames (`int`):
            Total number of frames that a video has.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. When omitted, every frame index is returned.

    Returns:
        `np.ndarray`: Integer array with the indices of the frames to sample.
    """
    if num_frames is None:
        return np.arange(0, total_num_frames).astype(int)

    # Positions are generated with a fractional step and truncated to integers afterwards.
    step = total_num_frames / num_frames
    return np.arange(0, total_num_frames, step).astype(int)
def default_sample_indices_fn(metadata: "VideoMetadata", num_frames=None, fps=None, **kwargs):
    """
    A default sampling function that replicates the logic used in get_uniform_frame_indices,
    while optionally handling `fps` if `num_frames` is not provided.

    Args:
        metadata (`VideoMetadata`):
            `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly.
        fps (`int` or `float`, *optional*):
            Desired frames per second. Only used when `num_frames` is `None`, in which case the number of
            frames to sample is derived from it and from the video's own frame rate.

    Returns:
        `np.ndarray`: Array of frame indices to sample. All frame indices are returned when neither
        `num_frames` nor `fps` is given.

    Raises:
        `ValueError`: If the number of frames derived from `fps` exceeds the number of frames in the video.
    """
    total_num_frames = metadata.total_num_frames
    video_fps = metadata.fps

    # If num_frames is not given but fps is, calculate num_frames from fps
    if num_frames is None and fps is not None:
        num_frames = int(total_num_frames / video_fps * fps)
        if num_frames > total_num_frames:
            raise ValueError(
                f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
                f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
            )

    if num_frames is None:
        return np.arange(0, total_num_frames, dtype=int)

    # Generate the positions with the fractional step first and truncate afterwards.
    # Passing `dtype=int` directly to `np.arange` makes NumPy truncate the *step* itself
    # (e.g. 2.5 -> 2, or 0.6 -> 0), which skews the sampling towards the start of the video
    # and no longer matches `get_uniform_frame_indices`.
    return np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
def read_video_opencv(
    video_path: str,
    sample_indices_fn: Callable,
    **kwargs,
):
    """
    Decode a video using the OpenCV backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that receives the video `metadata` (plus `**kwargs`) and returns the indices at which
            the video should be sampled.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import cv2
    requires_backends(read_video_opencv, ["cv2"])
    import cv2

    video = cv2.VideoCapture(video_path)
    try:
        total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = video.get(cv2.CAP_PROP_FPS)
        duration = total_num_frames / video_fps if video_fps else 0
        metadata = VideoMetadata(
            total_num_frames=int(total_num_frames),
            fps=float(video_fps),
            duration=float(duration),
            video_backend="opencv",
        )
        indices = sample_indices_fn(metadata=metadata, **kwargs)
        # Hash-based membership instead of scanning the index array for every decoded frame.
        wanted = set(np.asarray(indices).tolist())

        index = 0
        frames = []
        while video.isOpened():
            success, frame = video.read()
            if not success:
                break
            if index in wanted:
                # OpenCV decodes to BGR; the rest of the library expects RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            index += 1
            if index >= total_num_frames:
                break
    finally:
        # Release the capture even when sampling or decoding raises.
        video.release()

    metadata.frames_indices = indices
    return np.stack(frames), metadata
def read_video_decord(
    video_path: str,
    sample_indices_fn: Optional[Callable] = None,
    **kwargs,
):
    """
    Decode a video using the Decord backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`, *optional*):
            A callable that receives the video `metadata` (plus `**kwargs`) and returns the indices at which
            the video should be sampled. If not provided, `default_sample_indices_fn` is used, which samples
            uniformly and understands the `num_frames` and `fps` keyword arguments.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import from decord
    requires_backends(read_video_decord, ["decord"])
    from decord import VideoReader, cpu

    # The parameter defaults to `None`; fall back to uniform sampling instead of
    # failing with "'NoneType' object is not callable".
    if sample_indices_fn is None:
        sample_indices_fn = default_sample_indices_fn

    vr = VideoReader(uri=video_path, ctx=cpu(0))  # decord has problems with gpu
    video_fps = vr.get_avg_fps()
    total_num_frames = len(vr)
    duration = total_num_frames / video_fps if video_fps else 0
    metadata = VideoMetadata(
        total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
    )

    indices = sample_indices_fn(metadata=metadata, **kwargs)
    frames = vr.get_batch(indices).asnumpy()
    metadata.frames_indices = indices
    return frames, metadata
def read_video_pyav(
    video_path: str,
    sample_indices_fn: Callable,
    **kwargs,
):
    """
    Decode the video with PyAV decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that receives the video `metadata` (plus `**kwargs`) and returns the indices at which
            the video should be sampled. Decoding stops after the last returned index, so the indices are
            expected to be in increasing order.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import av
    requires_backends(read_video_pyav, ["av"])
    import av

    container = av.open(video_path)
    try:
        stream = container.streams.video[0]
        total_num_frames = stream.frames
        video_fps = stream.average_rate  # should we better use `av_guess_frame_rate`?
        duration = total_num_frames / video_fps if video_fps else 0
        metadata = VideoMetadata(
            total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
        )
        indices = sample_indices_fn(metadata=metadata, **kwargs)
        # Hash-based membership instead of scanning the index array for every decoded frame.
        wanted = set(np.asarray(indices).tolist())
        end_index = indices[-1]

        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break
            if i in wanted:
                frames.append(frame.to_ndarray(format="rgb24"))
    finally:
        # The container was previously left open; close it even when decoding fails.
        container.close()

    video = np.stack(frames)
    metadata.frames_indices = indices
    return video, metadata
def read_video_torchvision(
    video_path: str,
    sample_indices_fn: Callable,
    **kwargs,
):
    """
    Decode the video with the (deprecated) torchvision decoder.

    The whole video is read first and the sampled frames are selected from it afterwards.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that receives the video `metadata` (plus `**kwargs`) and returns the indices at which
            the video should be sampled.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of the sampled frames, read with `output_format="THWC"`.
            - `VideoMetadata` object.
    """
    warnings.warn(
        "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
        "Please use `torchcodec` instead."
    )

    read_options = {"start_pts": 0.0, "end_pts": None, "pts_unit": "sec", "output_format": "THWC"}
    all_frames, _, info = torchvision_io.read_video(video_path, **read_options)

    video_fps = info["video_fps"]
    total_num_frames = all_frames.size(0)
    if video_fps:
        duration = total_num_frames / video_fps
    else:
        duration = 0

    metadata = VideoMetadata(
        total_num_frames=int(total_num_frames),
        fps=float(video_fps),
        duration=float(duration),
        video_backend="torchvision",
    )

    indices = sample_indices_fn(metadata=metadata, **kwargs)
    sampled = all_frames[indices].contiguous().numpy()
    metadata.frames_indices = indices
    return sampled, metadata
def read_video_torchcodec(
    video_path: str,
    sample_indices_fn: Callable,
    **kwargs,
):
    """
    Decode the video with the torchcodec decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that receives the video `metadata` (plus `**kwargs`) and returns the indices at which
            the video should be sampled.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
            - The sampled frames as returned by the decoder (`.data` of the frame batch, made contiguous),
              requested in "NHWC" order.
            - `VideoMetadata` object.
    """
    # Lazy import torchcodec
    requires_backends(read_video_torchcodec, ["torchcodec"])
    from torchcodec.decoders import VideoDecoder

    # "NHWC" keeps the layout consistent with other decoders.
    # Interestingly `exact` mode takes less than approximate when we load the whole video
    decoder = VideoDecoder(video_path, dimension_order="NHWC", seek_mode="exact")

    decoder_info = decoder.metadata
    metadata = VideoMetadata(
        total_num_frames=decoder_info.num_frames,
        fps=decoder_info.average_fps,
        duration=decoder_info.duration_seconds,
        video_backend="torchcodec",
    )

    indices = sample_indices_fn(metadata=metadata, **kwargs)
    frame_batch = decoder.get_frames_at(indices=indices)
    video = frame_batch.data.contiguous()
    metadata.frames_indices = indices
    return video, metadata
# Maps the `backend` name accepted by `load_video` to the function that decodes a video with that backend.
VIDEO_DECODERS = {
    "decord": read_video_decord,
    "opencv": read_video_opencv,
    "pyav": read_video_pyav,
    "torchvision": read_video_torchvision,
    "torchcodec": read_video_torchcodec,
}
def load_video(
    video: Union[str, "VideoInput"],
    num_frames: Optional[int] = None,
    fps: Optional[Union[int, float]] = None,
    backend: str = "pyav",
    sample_indices_fn: Optional[Callable] = None,
    **kwargs,
) -> np.array:
    """
    Loads `video` to a numpy array.

    Args:
        video (`str` or `VideoInput`):
            The video to convert to the numpy array format. Can be a link to video or local path.
            Inputs that are already decoded (an image/array, or a list of them) are returned unchanged.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not passed, the whole video is loaded.
        fps (`int` or `float`, *optional*):
            Number of frames to sample per second. Should be passed only when `num_frames=None`.
            If not specified and `num_frames==None`, all frames are sampled.
        backend (`str`, *optional*, defaults to `"pyav"`):
            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
        sample_indices_fn (`Callable`, *optional*):
            A callable function that will return indices at which the video should be sampled. If the video has to be
            loaded using a different sampling technique than provided by `num_frames` or `fps` arguments, one should
            provide their own `sample_indices_fn`. If not provided, simple uniform sampling with fps is performed,
            otherwise `sample_indices_fn` has priority over other args. The function receives the video metadata
            along with all kwargs passed to `load_video` and should output valid indices at which the video should
            be sampled. For example:
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Array of decoded frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
        When `video` is already decoded, the input itself is returned instead of a tuple.

    Raises:
        `ValueError`: If both `num_frames` and `fps` are given without a custom `sample_indices_fn`, or if a URL
        is combined with the `opencv` or `torchvision` backend.
        `TypeError`: If `video` is neither a URL, an existing local file, nor an already decoded video.
        `ImportError`: If the library required by the chosen backend (or `yt_dlp` for YouTube links) is missing.
    """
    # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
    if fps is not None and num_frames is not None and sample_indices_fn is None:
        raise ValueError(
            "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
        )

    # If user didn't pass a sampling function, create one on the fly with default logic
    if sample_indices_fn is None:

        def sample_indices_fn_func(metadata, **fn_kwargs):
            return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)

        sample_indices_fn = sample_indices_fn_func

    # Already decoded inputs have to be recognised before any string operation is attempted:
    # `urlparse` and `str.startswith` fail on arrays, images and lists.
    if not isinstance(video, str):
        if is_valid_image(video) or (
            isinstance(video, (list, tuple)) and len(video) > 0 and is_valid_image(video[0])
        ):
            return video
        raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")

    # can also load with decord, but not cv2/torchvision
    # both will fail in case of url links; check this before downloading anything
    video_is_url = video.startswith(("http://", "https://"))
    if video_is_url and backend in ["opencv", "torchvision"]:
        raise ValueError(
            "If you are trying to load a video from URL, you can decode the video only with `pyav`, `decord` or `torchcodec` as backend"
        )

    if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
        if not is_yt_dlp_available():
            raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
        # Lazy import from yt_dlp
        requires_backends(load_video, ["yt_dlp"])
        from yt_dlp import YoutubeDL

        buffer = BytesIO()
        with redirect_stdout(buffer), YoutubeDL() as f:
            f.download([video])
        bytes_obj = buffer.getvalue()
        file_obj = BytesIO(bytes_obj)
    elif video_is_url:
        file_obj = BytesIO(requests.get(video).content)
    elif os.path.isfile(video):
        file_obj = video
    else:
        raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")

    if (
        (not is_decord_available() and backend == "decord")
        or (not is_av_available() and backend == "pyav")
        or (not is_cv2_available() and backend == "opencv")
        or (not is_torchvision_available() and backend == "torchvision")
        or (not is_torchcodec_available() and backend == "torchcodec")
    ):
        raise ImportError(
            f"You chose backend={backend} for loading the video but the required library is not found in your environment "
            f"Make sure to install {backend} before loading the video."
        )

    video_decoder = VIDEO_DECODERS[backend]
    video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
    return video, metadata
def convert_to_rgb(
    video: np.array,
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
    """
    Convert video to RGB. Grayscale videos are repeated over three channels and RGBA videos are blended with a
    white background according to their alpha channel; RGB videos are returned as they are.

    Args:
        video (`np.array`):
            The video to convert.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output video.
            NOTE(review): this argument is currently not applied; the result is always returned in
            channels-first layout.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input video. If unset, will use the inferred format from the input.

    Returns:
        `np.array`: The video with three channels, in channels-first layout.

    Raises:
        `TypeError`: If `video` is not a numpy array.
    """
    if not isinstance(video, np.ndarray):
        raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")

    # np.array usually comes with ChannelDimension.LAST so let's convert it
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(video)
    video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)

    num_channels = video.shape[-3]
    # 3 channels for RGB already
    if num_channels == 3:
        return video

    # Grayscale video so we repeat it 3 times for each channel
    if num_channels == 1:
        return video.repeat(3, -3)

    rgb = video[..., :3, :, :]
    alpha_channel = video[..., 3, :, :]

    # Fully opaque video: nothing to blend, just drop the alpha channel so the result is really RGB.
    if not (alpha_channel < 255).any():
        return rgb

    # There is a transparency layer, blend it with a white background.
    # The alpha proportion keeps a singleton channel axis so it broadcasts over the three colour
    # channels (the colour channels, not the alpha channel itself, are what must be blended).
    alpha = alpha_channel[..., None, :, :] / 255.0
    return (1 - alpha) * 255 + alpha * rgb
def pad(
    video: np.ndarray,
    padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
    mode: PaddingMode = PaddingMode.CONSTANT,
    constant_values: Union[float, Iterable[float]] = 0.0,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Pads the `video` with the specified (height, width) `padding` and `mode`.

    Args:
        video (`np.ndarray`):
            The video to pad.
        padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
            Padding to apply to the edges of the height, width axes. Can be one of three formats:
            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
            - `((before, after),)` or `(before, after)` yields same before and after pad for height and width.
            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
        mode (`PaddingMode`):
            The padding mode to use. Can be one of:
                - `"constant"`: pads with a constant value.
                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                  vector along each axis.
                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value to use for the padding if `mode` is `"constant"`. Accepts the same layouts as `padding`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output video. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
            If unset, will use same as the input video.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input video. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
            If unset, will use the inferred format of the input video.

    Returns:
        `np.ndarray`: The padded video.

    Raises:
        `ValueError`: If `padding`/`constant_values` have an unsupported layout or `mode` is not a known mode.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(video)

    def _expand_for_data_format(values):
        """
        Convert values to be in the format expected by np.pad based on the data format.
        """
        if isinstance(values, (int, float)):
            values = ((values, values), (values, values))
        elif isinstance(values, tuple) and len(values) == 1:
            single = values[0]
            # `((before, after),)` repeats the pair for both axes, `(pad,)` repeats the scalar everywhere.
            values = (single, single) if isinstance(single, tuple) else ((single, single), (single, single))
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], (int, float)):
            values = (values, values)
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
            values = values
        else:
            raise ValueError(f"Unsupported format: {values}")

        # add 0 for the frame and channel dimensions
        values = (
            ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0))
        )

        # Add additional padding if there's a batch dimension. It has to be a (before, after)
        # pair like every other axis, otherwise np.pad receives a ragged specification.
        values = ((0, 0), *values) if video.ndim == 5 else values
        return values

    # Names of the equivalent `np.pad` modes; NumPy calls edge replication "edge".
    padding_map = {
        PaddingMode.CONSTANT: "constant",
        PaddingMode.REFLECT: "reflect",
        PaddingMode.REPLICATE: "edge",
        PaddingMode.SYMMETRIC: "symmetric",
    }
    padding = _expand_for_data_format(padding)

    pad_kwargs = {}
    if mode not in padding_map:
        raise ValueError(f"Invalid padding mode: {mode}")
    elif mode == PaddingMode.CONSTANT:
        pad_kwargs["constant_values"] = _expand_for_data_format(constant_values)

    video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs)
    video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video
    return video
def group_videos_by_shape(
    videos: list["torch.Tensor"],
) -> tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[int, tuple[tuple[int, int], int]]]:
    """
    Bucket videos that share the same number of frames and spatial size, and stack each bucket.

    Returns:
        A pair made of:
            - a dictionary mapping each `(num_frames, height, width)` key to the stacked tensor of all videos
              with that key (stacked along a new first dimension);
            - a dictionary mapping the position of each video in the input list to `(key, position in the stack)`.
    """
    buckets = {}
    grouped_videos_index = {}
    for position, video in enumerate(videos):
        # The last two dimensions are the spatial size, the fourth from the end is the number of frames.
        spatial_size = video.shape[-2::]
        key = (video.shape[-4], *spatial_size)
        bucket = buckets.setdefault(key, [])
        grouped_videos_index[position] = (key, len(bucket))
        bucket.append(video)

    # stack videos with the same size and number of frames
    stacked = {}
    for key, members in buckets.items():
        stacked[key] = torch.stack(members, dim=0)
    return stacked, grouped_videos_index
def reorder_videos(
    processed_videos: dict[tuple[int, int], "torch.Tensor"], grouped_videos_index: dict[int, tuple[int, int]]
) -> list["torch.Tensor"]:
    """
    Rebuild the list of videos in their original order.

    `grouped_videos_index` maps each original position (0, 1, 2, ...) to a `(group key, position inside the
    group)` pair, which is used to look every video up in `processed_videos`.
    """
    ordered = []
    for position in range(len(grouped_videos_index)):
        location = grouped_videos_index[position]
        group = processed_videos[location[0]]
        ordered.append(group[location[1]])
    return ordered