# team-10/venv/Lib/site-packages/torchvision/io/_video_opt.py
# 2025-08-02 02:00:33 +02:00
#
# 521 lines
# 20 KiB
# Python

import math
import warnings
from fractions import Fraction
from typing import Dict, List, Optional, Tuple, Union
import torch
from ..extension import _load_library
from ._video_deprecation_warning import _raise_video_deprecation_warning
# Attempt to load the compiled "video_reader" extension and remember whether
# that succeeded; a missing or unloadable library is not fatal at import time.
try:
    _load_library("video_reader")
except (ImportError, OSError):
    _HAS_CPU_VIDEO_DECODER = False
else:
    _HAS_CPU_VIDEO_DECODER = True

# Older name for the same flag, kept for backward compatibility.
_HAS_VIDEO_OPT = _HAS_CPU_VIDEO_DECODER

# Default timebase (0/1) used by the readers below when none is supplied.
default_timebase = Fraction(0, 1)
# simple class for torch scripting
# the complex Fraction class from fractions module is not scriptable
class Timebase:
    """Rational timebase held as a plain integer numerator/denominator pair."""

    # Attribute types are declared explicitly and instances are restricted to
    # these two slots.
    __annotations__ = {"numerator": int, "denominator": int}
    __slots__ = ["numerator", "denominator"]

    def __init__(
        self,
        numerator: int,
        denominator: int,
    ) -> None:
        # Values are stored as given; no normalisation or validation is done.
        self.numerator = numerator
        self.denominator = denominator
class VideoMetaData:
    """Container for the video/audio stream information filled in by `_fill_info`.

    A freshly constructed instance describes "no streams": both `has_*` flags
    are False, both timebases are 0/1 and every numeric field is 0.0.
    """

    # Declared attribute types.
    __annotations__ = {
        "has_video": bool,
        "video_timebase": Timebase,
        "video_duration": float,
        "video_fps": float,
        "has_audio": bool,
        "audio_timebase": Timebase,
        "audio_duration": float,
        "audio_sample_rate": float,
    }
    # Instances may only carry the attributes listed here.
    __slots__ = [
        "has_video",
        "video_timebase",
        "video_duration",
        "video_fps",
        "has_audio",
        "audio_timebase",
        "audio_duration",
        "audio_sample_rate",
    ]

    def __init__(self) -> None:
        # Video stream defaults.
        self.has_video = False
        self.video_timebase = Timebase(0, 1)
        self.video_duration = 0.0
        self.video_fps = 0.0
        # Audio stream defaults.
        self.has_audio = False
        self.audio_timebase = Timebase(0, 1)
        self.audio_duration = 0.0
        self.audio_sample_rate = 0.0
def _validate_pts(pts_range: Tuple[int, int]) -> None:
if pts_range[0] > pts_range[1] > 0:
raise ValueError(
f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
)
def _fill_info(
    vtimebase: torch.Tensor,
    vfps: torch.Tensor,
    vduration: torch.Tensor,
    atimebase: torch.Tensor,
    asample_rate: torch.Tensor,
    aduration: torch.Tensor,
) -> VideoMetaData:
    """
    Build a VideoMetaData struct from the tensors describing the streams.

    Any tensor that is empty leaves the corresponding field at its default.
    A stream is only flagged as present when both its timebase and its
    duration tensors are non-empty.
    """
    info = VideoMetaData()

    # Video stream
    if vtimebase.numel() > 0:
        info.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
        video_seconds_per_tick = vtimebase[0].item() / float(vtimebase[1].item())
        if vduration.numel() > 0:
            info.has_video = True
            # duration is given in timebase ticks; convert to seconds
            info.video_duration = float(vduration.item()) * video_seconds_per_tick
    if vfps.numel() > 0:
        info.video_fps = float(vfps.item())

    # Audio stream
    if atimebase.numel() > 0:
        info.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
        audio_seconds_per_tick = atimebase[0].item() / float(atimebase[1].item())
        if aduration.numel() > 0:
            info.has_audio = True
            info.audio_duration = float(aduration.item()) * audio_seconds_per_tick
    if asample_rate.numel() > 0:
        info.audio_sample_rate = float(asample_rate.item())

    return info
def _align_audio_frames(
aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
) -> torch.Tensor:
start, end = aframe_pts[0], aframe_pts[-1]
num_samples = aframes.size(0)
step_per_aframe = float(end - start + 1) / float(num_samples)
s_idx = 0
e_idx = num_samples
if start < audio_pts_range[0]:
s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
return aframes[s_idx:e_idx, :]
def _read_video_from_file(
    filename: str,
    seek_frame_margin: float = 0.25,
    read_video_stream: bool = True,
    video_width: int = 0,
    video_height: int = 0,
    video_min_dimension: int = 0,
    video_max_dimension: int = 0,
    video_pts_range: Tuple[int, int] = (0, -1),
    video_timebase: Fraction = default_timebase,
    read_audio_stream: bool = True,
    audio_samples: int = 0,
    audio_channels: int = 0,
    audio_pts_range: Tuple[int, int] = (0, -1),
    audio_timebase: Fraction = default_timebase,
) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
    """
    Decode the video and audio streams of a file.

    Args:
        filename (str): path to the video file
        seek_frame_margin (float, optional): seeking in the stream is imprecise, so
            when a start pts is given the seek targets a position this many seconds earlier
        read_video_stream (bool, optional): whether to read the video stream
        video_width/video_height/video_min_dimension/video_max_dimension (int): together
            decide the size of the decoded frames. With a value of 0 meaning "unset":
            - all four unset: keep the original frame resolution
            - only video_min_dimension set: keep the aspect ratio, shorter edge
              becomes video_min_dimension
            - only video_max_dimension set: keep the aspect ratio, longer edge
              becomes video_max_dimension
            - video_min_dimension and video_max_dimension set: shorter edge becomes
              video_min_dimension and longer edge video_max_dimension; the aspect
              ratio may not be preserved
            - only video_height set: keep the aspect ratio, frame height becomes video_height
            - only video_width set: keep the aspect ratio, frame width becomes video_width
            - video_width and video_height set: frames are resized to exactly
              video_width x video_height
        video_pts_range (Tuple[int, int], optional): start and end presentation
            timestamp of the video stream
        video_timebase (Fraction, optional): timebase of the video stream
        read_audio_stream (bool, optional): whether to read the audio stream
        audio_samples (int, optional): audio sampling rate
        audio_channels (int, optional): audio channels
        audio_pts_range (Tuple[int, int], optional): start and end presentation
            timestamp of the audio stream
        audio_timebase (Fraction, optional): timebase of the audio stream

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
            `K` is the number of audio_channels
        info (VideoMetaData): metadata for the video and audio streams
    """
    _raise_video_deprecation_warning()
    for requested_range in (video_pts_range, audio_pts_range):
        _validate_pts(requested_range)

    decoded = torch.ops.video_reader.read_video_from_file(
        filename,
        seek_frame_margin,
        0,  # getPtsOnly
        # --- video stream options ---
        read_video_stream,
        video_width,
        video_height,
        video_min_dimension,
        video_max_dimension,
        video_pts_range[0],
        video_pts_range[1],
        video_timebase.numerator,
        video_timebase.denominator,
        # --- audio stream options ---
        read_audio_stream,
        audio_samples,
        audio_channels,
        audio_pts_range[0],
        audio_pts_range[1],
        audio_timebase.numerator,
        audio_timebase.denominator,
    )
    (
        vframes,
        _vframe_pts,
        vtimebase,
        vfps,
        vduration,
        aframes,
        aframe_pts,
        atimebase,
        asample_rate,
        aduration,
    ) = decoded

    metadata = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    # Only trim the audio when an audio stream was actually decoded.
    if aframes.numel() > 0:
        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)

    return vframes, aframes, metadata
def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Collect the presentation timestamps of every video and audio frame in a file.

    The decoder is asked for pts only, so frame pixel data is not copied and the
    call is much faster than read_video(...).

    Returns:
        the video frame pts, the audio frame pts (both as Python lists) and the
        VideoMetaData describing the streams.
    """
    decoded = torch.ops.video_reader.read_video_from_file(
        filename,
        0,  # seek_frame_margin
        1,  # getPtsOnly
        # video stream: read it, original size, full range, timebase 0/1
        1,  # read_video_stream
        0,  # video_width
        0,  # video_height
        0,  # video_min_dimension
        0,  # video_max_dimension
        0,  # video_start_pts
        -1,  # video_end_pts
        0,  # video_timebase_num
        1,  # video_timebase_den
        # audio stream: read it, defaults, full range, timebase 0/1
        1,  # read_audio_stream
        0,  # audio_samples
        0,  # audio_channels
        0,  # audio_start_pts
        -1,  # audio_end_pts
        0,  # audio_timebase_num
        1,  # audio_timebase_den
    )
    (
        _vframes,
        video_pts_tensor,
        vtimebase,
        vfps,
        vduration,
        _aframes,
        audio_pts_tensor,
        atimebase,
        asample_rate,
        aduration,
    ) = decoded

    metadata = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    video_pts = video_pts_tensor.numpy().tolist()
    audio_pts = audio_pts_tensor.numpy().tolist()
    return video_pts, audio_pts, metadata
def _probe_video_from_file(filename: str) -> VideoMetaData:
    """
    Inspect a video file and report its stream information as VideoMetaData.
    """
    _raise_video_deprecation_warning()
    (
        vtimebase,
        vfps,
        vduration,
        atimebase,
        asample_rate,
        aduration,
    ) = torch.ops.video_reader.probe_video_from_file(filename)
    return _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
def _read_video_from_memory(
    video_data: torch.Tensor,
    seek_frame_margin: float = 0.25,
    read_video_stream: int = 1,
    video_width: int = 0,
    video_height: int = 0,
    video_min_dimension: int = 0,
    video_max_dimension: int = 0,
    video_pts_range: Tuple[int, int] = (0, -1),
    video_timebase_numerator: int = 0,
    video_timebase_denominator: int = 1,
    read_audio_stream: int = 1,
    audio_samples: int = 0,
    audio_channels: int = 0,
    audio_pts_range: Tuple[int, int] = (0, -1),
    audio_timebase_numerator: int = 0,
    audio_timebase_denominator: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decode the video and audio streams of an in-memory video.
    This function is torchscriptable.

    Args:
        video_data (torch.Tensor or buffer object): compressed video content. Anything
            that is not already a tensor is wrapped with torch.frombuffer as uint8
        seek_frame_margin (float, optional): seeking in the stream is imprecise, so
            when a start pts is given the seek targets a position this many seconds earlier
        read_video_stream (int, optional): whether to read the video stream. 1 for yes, 0 for no
        video_width/video_height/video_min_dimension/video_max_dimension (int): together
            decide the size of the decoded frames. With a value of 0 meaning "unset":
            - all four unset: keep the original frame resolution
            - only video_min_dimension set: keep the aspect ratio, shorter edge
              becomes video_min_dimension
            - only video_max_dimension set: keep the aspect ratio, longer edge
              becomes video_max_dimension
            - video_min_dimension and video_max_dimension set: shorter edge becomes
              video_min_dimension and longer edge video_max_dimension; the aspect
              ratio may not be preserved
            - only video_height set: keep the aspect ratio, frame height becomes video_height
            - only video_width set: keep the aspect ratio, frame width becomes video_width
            - video_width and video_height set: frames are resized to exactly
              video_width x video_height
        video_pts_range (Tuple[int, int], optional): start and end presentation
            timestamp of the video stream
        video_timebase_numerator / video_timebase_denominator (int, optional): a rational
            number which denotes the timebase of the video stream
        read_audio_stream (int, optional): whether to read the audio stream. 1 for yes, 0 for no
        audio_samples (int, optional): audio sampling rate
        audio_channels (int, optional): audio channels
        audio_pts_range (Tuple[int, int], optional): start and end presentation
            timestamp of the audio stream
        audio_timebase_numerator / audio_timebase_denominator (int, optional): a rational
            number which denotes the timebase of the audio stream

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
            `K` is the number of channels
    """
    _raise_video_deprecation_warning()
    _validate_pts(video_pts_range)
    _validate_pts(audio_pts_range)

    if not isinstance(video_data, torch.Tensor):
        with warnings.catch_warnings():
            # The buffer is never written to here, so the "not writable" warning is noise.
            warnings.filterwarnings("ignore", message="The given buffer is not writable")
            video_data = torch.frombuffer(video_data, dtype=torch.uint8)

    decoded = torch.ops.video_reader.read_video_from_memory(
        video_data,
        seek_frame_margin,
        0,  # getPtsOnly
        # --- video stream options ---
        read_video_stream,
        video_width,
        video_height,
        video_min_dimension,
        video_max_dimension,
        video_pts_range[0],
        video_pts_range[1],
        video_timebase_numerator,
        video_timebase_denominator,
        # --- audio stream options ---
        read_audio_stream,
        audio_samples,
        audio_channels,
        audio_pts_range[0],
        audio_pts_range[1],
        audio_timebase_numerator,
        audio_timebase_denominator,
    )

    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = decoded

    # Only trim the audio when an audio stream was actually decoded.
    if aframes.numel() > 0:
        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)

    return vframes, aframes
def _read_video_timestamps_from_memory(
    video_data: torch.Tensor,
) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Collect the presentation timestamps of every frame of an in-memory video.

    The decoder is asked for pts only, so frame pixel data is not copied and
    read_video_timestamps(...) is much faster than read_video(...).

    Returns:
        the video frame pts, the audio frame pts (both as Python lists) and the
        VideoMetaData describing the streams.
    """
    if not isinstance(video_data, torch.Tensor):
        with warnings.catch_warnings():
            # The buffer is never written to here, so the "not writable" warning is noise.
            warnings.filterwarnings("ignore", message="The given buffer is not writable")
            video_data = torch.frombuffer(video_data, dtype=torch.uint8)

    decoded = torch.ops.video_reader.read_video_from_memory(
        video_data,
        0,  # seek_frame_margin
        1,  # getPtsOnly
        # video stream: read it, original size, full range, timebase 0/1
        1,  # read_video_stream
        0,  # video_width
        0,  # video_height
        0,  # video_min_dimension
        0,  # video_max_dimension
        0,  # video_start_pts
        -1,  # video_end_pts
        0,  # video_timebase_num
        1,  # video_timebase_den
        # audio stream: read it, defaults, full range, timebase 0/1
        1,  # read_audio_stream
        0,  # audio_samples
        0,  # audio_channels
        0,  # audio_start_pts
        -1,  # audio_end_pts
        0,  # audio_timebase_num
        1,  # audio_timebase_den
    )
    # The deprecation warning is emitted once the decode call has returned.
    _raise_video_deprecation_warning()

    (
        _vframes,
        video_pts_tensor,
        vtimebase,
        vfps,
        vduration,
        _aframes,
        audio_pts_tensor,
        atimebase,
        asample_rate,
        aduration,
    ) = decoded

    metadata = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    video_pts = video_pts_tensor.numpy().tolist()
    audio_pts = audio_pts_tensor.numpy().tolist()
    return video_pts, audio_pts, metadata
def _probe_video_from_memory(
    video_data: torch.Tensor,
) -> VideoMetaData:
    """
    Inspect an in-memory video and report its stream information as VideoMetaData.
    This function is torchscriptable
    """
    _raise_video_deprecation_warning()

    if not isinstance(video_data, torch.Tensor):
        with warnings.catch_warnings():
            # The buffer is never written to here, so the "not writable" warning is noise.
            warnings.filterwarnings("ignore", message="The given buffer is not writable")
            video_data = torch.frombuffer(video_data, dtype=torch.uint8)

    probed = torch.ops.video_reader.probe_video_from_memory(video_data)
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = probed
    return _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
def _read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
    """
    Read the video and audio frames of a file between two timestamps.

    Args:
        filename (str): path to the video file
        start_pts (float or Fraction, optional): start of the range to read
        end_pts (float or Fraction, optional): end of the range to read; ``None``
            means "until the end of the stream"
        pts_unit (str, optional): ``"sec"`` to interpret start_pts/end_pts as seconds;
            any other value passes them through unchanged as pts. Using ``"pts"``
            emits a warning.

    Returns:
        the video frames, the audio frames and a dict that holds ``"video_fps"``
        and/or ``"audio_fps"`` for the streams found by the probe.
    """
    _raise_video_deprecation_warning()

    if end_pts is None:
        end_pts = float("inf")

    if pts_unit == "pts":
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            "follow-up version. Please use pts_unit 'sec'."
        )

    probed = _probe_video_from_file(filename)
    has_video = probed.has_video
    has_audio = probed.has_audio

    def get_pts(time_base):
        # Translate (start_pts, end_pts) into the pts pair handed to the decoder
        # for a stream with the given timebase; an open end is encoded as -1.
        open_ended = end_pts == float("inf")
        if pts_unit == "sec":
            ticks_per_second = 1 / time_base
            first = int(math.floor(start_pts * ticks_per_second))
            last = -1 if open_ended else int(math.ceil(end_pts * ticks_per_second))
        else:
            first = start_pts
            last = -1 if open_ended else end_pts
        return first, last

    # Defaults apply to any stream the probe did not find.
    video_pts_range = (0, -1)
    video_timebase = default_timebase
    if has_video:
        video_timebase = Fraction(probed.video_timebase.numerator, probed.video_timebase.denominator)
        video_pts_range = get_pts(video_timebase)

    audio_pts_range = (0, -1)
    audio_timebase = default_timebase
    if has_audio:
        audio_timebase = Fraction(probed.audio_timebase.numerator, probed.audio_timebase.denominator)
        audio_pts_range = get_pts(audio_timebase)

    vframes, aframes, decoded_info = _read_video_from_file(
        filename,
        read_video_stream=True,
        video_pts_range=video_pts_range,
        video_timebase=video_timebase,
        read_audio_stream=True,
        audio_pts_range=audio_pts_range,
        audio_timebase=audio_timebase,
    )

    # Report rates only for the streams the initial probe detected.
    summary = {}
    if has_video:
        summary["video_fps"] = decoded_info.video_fps
    if has_audio:
        summary["audio_fps"] = decoded_info.audio_sample_rate

    return vframes, aframes, summary
def _read_video_timestamps(
    filename: str, pts_unit: str = "pts"
) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
    """
    List the presentation timestamps of the video frames in a file.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): ``"sec"`` converts every pts to seconds by
            multiplying it with the video timebase; any other value returns the
            raw pts. Using ``"pts"`` emits a warning.

    Returns:
        the list of timestamps and the video fps, or ``None`` for the fps when
        the file has no video stream.
    """
    _raise_video_deprecation_warning()

    if pts_unit == "pts":
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            "follow-up version. Please use pts_unit 'sec'."
        )

    pts: Union[List[int], List[Fraction]]
    pts, _, metadata = _read_video_timestamps_from_file(filename)

    if pts_unit == "sec":
        seconds_per_tick = Fraction(metadata.video_timebase.numerator, metadata.video_timebase.denominator)
        pts = [tick * seconds_per_tick for tick in pts]

    if metadata.has_video:
        video_fps = metadata.video_fps
    else:
        video_fps = None
    return pts, video_fps