team-10/env/Lib/site-packages/streamlit/web/server/media_file_handler.py

149 lines
5.4 KiB
Python
Raw Normal View History

2025-08-02 07:34:44 +02:00
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any, cast
from urllib.parse import quote
import tornado.web
from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorageError
from streamlit.runtime.memory_media_file_storage import (
MemoryMediaFileStorage,
get_extension_for_mimetype,
)
from streamlit.web.server import allow_all_cross_origin_requests, is_allowed_origin
# Module-level logger; MediaFileHandler uses it for debug messages and for
# reporting files that cannot be found in the media storage.
_LOGGER = get_logger(__name__)
class MediaFileHandler(tornado.web.StaticFileHandler):
    """Serve media files held in a ``MemoryMediaFileStorage``.

    Tornado's ``StaticFileHandler`` normally reads from the filesystem. This
    subclass overrides the filesystem hooks so that request paths are used
    directly as keys into the shared in-memory media storage instead.
    """

    # Storage shared by all handler instances; set once via
    # `initialize_storage()`.
    _storage: MemoryMediaFileStorage

    @classmethod
    def initialize_storage(cls, storage: MemoryMediaFileStorage) -> None:
        """Set the MemoryMediaFileStorage object used by instances of this
        handler. Must be called on server startup.
        """
        # This is a class method, rather than an instance method, because
        # `get_content()` is a class method and needs to access the storage
        # instance.
        cls._storage = storage

    def set_default_headers(self) -> None:
        """Set the CORS ``Access-Control-Allow-Origin`` response header.

        Uses ``*`` when all cross-origin requests are allowed. Otherwise the
        request's ``Origin`` header is echoed back only if it is an allowed
        origin.
        """
        if allow_all_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")
        elif is_allowed_origin(origin := self.request.headers.get("Origin")):
            # cast() only narrows the Optional[str] type for the type checker.
            self.set_header("Access-Control-Allow-Origin", cast("str", origin))

    def set_extra_headers(self, path: str) -> None:
        """Add Content-Disposition header for downloadable files.

        Set header value to "attachment" indicating that file should be saved
        locally instead of displaying inline in browser.

        We also set filename to specify the filename for downloaded files.
        Used for serving downloadable files, like files stored via the
        `st.download_button` widget.

        Parameters
        ----------
        path : str
            The requested path. This handler uses it directly as the storage
            key (see `get_absolute_path`).
        """
        media_file = self._storage.get_file(path)

        if media_file and media_file.kind == MediaFileKind.DOWNLOADABLE:
            filename = media_file.filename
            if not filename:
                filename = f"streamlit_download{get_extension_for_mimetype(media_file.mimetype)}"

            file_expr = self._content_disposition_filename_param(filename)
            self.set_header("Content-Disposition", f"attachment; {file_expr}")

    @staticmethod
    def _content_disposition_filename_param(filename: str) -> str:
        """Build the filename parameter of a Content-Disposition header.

        The plain ``filename="..."`` form is used only when the name can be
        encoded in latin1 (the default encoding for headers) and contains
        nothing that would break the quoted-string or be rejected as a header
        value: no ``"``, no ``\\``, and no ASCII control characters.
        Every other name uses the RFC 5987 ``filename*=utf-8''...`` form with
        percent-encoded UTF-8.

        Parameters
        ----------
        filename : str
            The download filename to advertise to the browser.

        Returns
        -------
        str
            Either ``filename="<name>"`` or ``filename*=utf-8''<pct-encoded>``.
        """
        try:
            filename.encode("latin1")
        except UnicodeEncodeError:
            needs_extended_form = True
        else:
            # Quotes/backslashes would terminate or escape the quoted-string
            # early, and control characters are not valid in header values.
            needs_extended_form = any(
                ch in '"\\' or ord(ch) < 0x20 or ch == "\x7f" for ch in filename
            )

        if not needs_extended_form:
            return f'filename="{filename}"'

        # RFC5987 syntax.
        # See: https://datatracker.ietf.org/doc/html/rfc5987
        # safe="" so that "/" is percent-encoded as well: it is not an
        # RFC 5987 attr-char, and quote() leaves it unescaped by default.
        return f"filename*=utf-8''{quote(filename, safe='')}"

    # Overriding StaticFileHandler to use the MediaFileManager
    #
    # From the Tornado docs:
    # To replace all interaction with the filesystem (e.g. to serve
    # static content from a database), override `get_content`,
    # `get_content_size`, `get_modified_time`, `get_absolute_path`, and
    # `validate_absolute_path`.

    def validate_absolute_path(
        self,
        root: str,  # noqa: ARG002
        absolute_path: str,
    ) -> str:
        """Return ``absolute_path`` if it exists in storage.

        Raises
        ------
        tornado.web.HTTPError
            404 when the storage raises ``MediaFileStorageError`` for the path.
        """
        try:
            self._storage.get_file(absolute_path)
        except MediaFileStorageError:
            _LOGGER.exception("MediaFileHandler: Missing file %s", absolute_path)
            raise tornado.web.HTTPError(404, "not found")

        return absolute_path

    def get_content_size(self) -> int:
        """Return the stored file's ``content_size``, or 0 if no path is set."""
        abspath = self.absolute_path
        if abspath is None:
            return 0

        media_file = self._storage.get_file(abspath)
        return media_file.content_size

    def get_modified_time(self) -> None:
        """Return None: modification times are not tracked for media files."""
        # We do not track last modified time, but this can be improved to
        # allow caching among files in the MediaFileManager
        return None

    @classmethod
    def get_absolute_path(cls, root: str, path: str) -> str:  # noqa: ARG003
        """Return ``path`` unchanged; it is used directly as the storage key."""
        # All files are stored in memory, so the absolute path is just the
        # path itself. In the MediaFileHandler, it's just the filename
        return path

    @classmethod
    def get_content(
        cls, abspath: str, start: int | None = None, end: int | None = None
    ) -> Any:
        """Return the stored file's content, or the ``[start:end]`` slice.

        ``start`` defaults to 0 and ``end`` to the content length when only
        one of them is given. ``end`` is exclusive (Python slice semantics).
        Returns None if the file lookup raises.
        """
        _LOGGER.debug("MediaFileHandler: GET %s", abspath)

        try:
            # abspath is the hash as used `get_absolute_path`
            media_file = cls._storage.get_file(abspath)
        except Exception:
            _LOGGER.exception("MediaFileHandler: Missing file %s", abspath)
            # NOTE(review): best-effort None return; confirm the
            # StaticFileHandler callers tolerate a None result here.
            return None

        _LOGGER.debug(
            "MediaFileHandler: Sending %s file %s", media_file.mimetype, abspath
        )

        # If there is no start and end, just return the full content
        # (avoids copying the bytes via a slice).
        if start is None and end is None:
            return media_file.content

        if start is None:
            start = 0
        if end is None:
            end = len(media_file.content)

        # content is bytes that work just by slicing supplied by start and end
        return media_file.content[start:end]