176 lines
6.1 KiB
Python
176 lines
6.1 KiB
Python
![]() |
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import io
|
||
|
import os
|
||
|
import re
|
||
|
from pathlib import Path
|
||
|
|
||
|
from streamlit import runtime
|
||
|
from streamlit.runtime import caching
|
||
|
from streamlit.util import calc_md5
|
||
|
|
||
|
# Regular expression to match the SRT timestamp format
# It matches the
# "hours:minutes:seconds,milliseconds --> hours:minutes:seconds,milliseconds" format
SRT_VALIDATION_REGEX = r"\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}"

# Matches a single SRT timestamp and captures "hours:minutes:seconds" and
# "milliseconds" as separate groups, so the comma between them can be
# rewritten (see _srt_to_vtt).
SRT_CONVERSION_REGEX = r"(\d{2}:\d{2}:\d{2}),(\d{3})"

# Lower-case file extensions accepted when subtitles are passed as a file path.
SUBTITLE_ALLOWED_FORMATS = (".srt", ".vtt")
|
||
|
|
||
|
|
||
|
def _is_srt(stream: str | io.BytesIO | bytes) -> bool:
|
||
|
# Handle raw bytes
|
||
|
if isinstance(stream, bytes):
|
||
|
stream = io.BytesIO(stream)
|
||
|
|
||
|
# Convert str to io.BytesIO if 'stream' is a string
|
||
|
if isinstance(stream, str):
|
||
|
stream = io.BytesIO(stream.encode("utf-8"))
|
||
|
|
||
|
# Set the stream position to the beginning in case it's been moved
|
||
|
stream.seek(0)
|
||
|
|
||
|
# Read enough bytes to reliably check for SRT patterns
|
||
|
# This might be adjusted, but 33 bytes should be enough to read the first numeric
|
||
|
# line, the full timestamp line, and a bit of the next line
|
||
|
header = stream.read(33)
|
||
|
|
||
|
try:
|
||
|
header_str = header.decode("utf-8").strip() # Decode and strip whitespace
|
||
|
except UnicodeDecodeError:
|
||
|
# If it's not valid utf-8, it's probably not a valid SRT file
|
||
|
return False
|
||
|
|
||
|
# Split the header into lines and process them
|
||
|
lines = header_str.split("\n")
|
||
|
|
||
|
# Check for the pattern of an SRT file: digit(s), newline, timestamp
|
||
|
if len(lines) >= 2 and lines[0].isdigit():
|
||
|
match = re.search(SRT_VALIDATION_REGEX, lines[1])
|
||
|
if match:
|
||
|
return True
|
||
|
|
||
|
return False
|
||
|
|
||
|
|
||
|
def _srt_to_vtt(srt_data: str | bytes) -> bytes:
|
||
|
"""
|
||
|
Convert subtitles from SubRip (.srt) format to WebVTT (.vtt) format.
|
||
|
This function accepts the content of the .srt file either as a string
|
||
|
or as a BytesIO stream.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
srt_data : str or bytes
|
||
|
The content of the .srt file as a string or a bytes stream.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
bytes
|
||
|
The content converted into .vtt format.
|
||
|
"""
|
||
|
|
||
|
# If the input is a bytes stream, convert it to a string
|
||
|
if isinstance(srt_data, bytes):
|
||
|
# Decode the bytes to a UTF-8 string
|
||
|
try:
|
||
|
srt_data = srt_data.decode("utf-8")
|
||
|
except UnicodeDecodeError as e:
|
||
|
raise ValueError("Could not decode the input stream as UTF-8.") from e
|
||
|
if not isinstance(srt_data, str):
|
||
|
# If it's not a string by this point, something is wrong.
|
||
|
raise TypeError(
|
||
|
f"Input must be a string or a bytes stream, not {type(srt_data)}."
|
||
|
)
|
||
|
|
||
|
# Replace SubRip timing with WebVTT timing
|
||
|
vtt_data = re.sub(SRT_CONVERSION_REGEX, r"\1.\2", srt_data)
|
||
|
|
||
|
# Add WebVTT file header
|
||
|
vtt_content = "WEBVTT\n\n" + vtt_data
|
||
|
# Convert the vtt content to bytes
|
||
|
return vtt_content.strip().encode("utf-8")
|
||
|
|
||
|
|
||
|
def _handle_string_or_path_data(data_or_path: str | Path) -> bytes:
|
||
|
"""Handles string data, either as a file path or raw content."""
|
||
|
if os.path.isfile(data_or_path):
|
||
|
path = Path(data_or_path)
|
||
|
file_extension = path.suffix.lower()
|
||
|
|
||
|
if file_extension not in SUBTITLE_ALLOWED_FORMATS:
|
||
|
raise ValueError(
|
||
|
f"Incorrect subtitle format {file_extension}. Subtitles must be in "
|
||
|
f"one of the following formats: {', '.join(SUBTITLE_ALLOWED_FORMATS)}"
|
||
|
)
|
||
|
with open(data_or_path, "rb") as file:
|
||
|
content = file.read()
|
||
|
return _srt_to_vtt(content) if file_extension == ".srt" else content
|
||
|
if isinstance(data_or_path, Path):
|
||
|
raise ValueError(f"File {data_or_path} does not exist.") # noqa: TRY004
|
||
|
|
||
|
content_string = data_or_path.strip()
|
||
|
|
||
|
if content_string.startswith("WEBVTT") or content_string == "":
|
||
|
return content_string.encode("utf-8")
|
||
|
if _is_srt(content_string):
|
||
|
return _srt_to_vtt(content_string)
|
||
|
raise ValueError("The provided string neither matches valid VTT nor SRT format.")
|
||
|
|
||
|
|
||
|
def _handle_stream_data(stream: io.BytesIO) -> bytes:
    """Return the full content of ``stream``, converted to WebVTT if it is SRT.

    The stream is rewound first; its whole buffer is used regardless of the
    position it had when passed in.
    """
    stream.seek(0)
    payload = stream.getvalue()

    # Non-SRT content is passed through untouched
    if _is_srt(stream):
        return _srt_to_vtt(payload)
    return payload
|
||
|
|
||
|
|
||
|
def _handle_bytes_data(data: bytes) -> bytes:
    """Return raw subtitle bytes, converted to WebVTT if they contain SRT."""
    # Non-SRT content is passed through untouched
    if _is_srt(data):
        return _srt_to_vtt(data)
    return data
|
||
|
|
||
|
|
||
|
def process_subtitle_data(
    coordinates: str,
    data: str | bytes | Path | io.BytesIO,
    label: str,
) -> str:
    """Normalize subtitle input to WebVTT and register it as a media file.

    Parameters
    ----------
    coordinates : str
        Passed through unchanged to the media file manager and to the media
        data cache.
    data : str, bytes, Path or io.BytesIO
        The subtitles, given as a file path, inline text, raw bytes or a
        binary stream. SRT content is converted to WebVTT.
    label : str
        Its MD5 hash is used as the base name of the registered ``.vtt`` file.

    Returns
    -------
    str
        The URL returned by the media file manager, or an empty string when
        no Streamlit runtime exists.

    Raises
    ------
    TypeError
        If ``data`` is not one of the supported types.
    """
    # Pick the handler matching the input type
    if isinstance(data, (str, Path)):
        vtt_bytes = _handle_string_or_path_data(data)
    elif isinstance(data, io.BytesIO):
        vtt_bytes = _handle_stream_data(data)
    elif isinstance(data, bytes):
        vtt_bytes = _handle_bytes_data(data)
    else:
        raise TypeError(f"Invalid binary data format for subtitle: {type(data)}.")

    # When running in "raw mode", we can't access the MediaFileManager.
    if not runtime.exists():
        return ""

    # Register the processed data with the media file manager and the cache,
    # then return the resulting file URL
    hashed_label = calc_md5(label.encode())
    media_manager = runtime.get_instance().media_file_mgr
    url = media_manager.add(
        path_or_data=vtt_bytes,
        mimetype="text/vtt",
        coordinates=coordinates,
        file_name=f"{hashed_label}.vtt",
    )
    caching.save_media_data(vtt_bytes, "text/vtt", coordinates)
    return url
|