# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides global MediaFileManager object as `media_file_manager`."""

from __future__ import annotations

import collections
import threading
from typing import Final

from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorage

_LOGGER: Final = get_logger(__name__)


def _get_session_id() -> str:
    """Get the active AppSession's session_id.

    Returns
    -------
    str
        The session ID of the current script-run context, or the constant
        "dontcare" when there is no script-run context.
    """
    # Imported inside the function rather than at module level -- presumably
    # to avoid an import cycle; TODO confirm.
    from streamlit.runtime.scriptrunner_utils.script_run_context import (
        get_script_run_ctx,
    )

    ctx = get_script_run_ctx()
    if ctx is None:
        # This is only None when running "python myscript.py" rather than
        # "streamlit run myscript.py". In that case the session ID doesn't
        # matter and can just be a constant, as there's only ever one session.
        return "dontcare"
    return ctx.session_id


class MediaFileMetadata:
    """Metadata that the MediaFileManager needs for each file it manages."""

    def __init__(self, kind: MediaFileKind = MediaFileKind.MEDIA) -> None:
        self._kind = kind
        # Set by MediaFileManager.remove_orphaned_files the first time an
        # inactive DOWNLOADABLE file is seen; a later pass then deletes it.
        self._is_marked_for_delete = False

    @property
    def kind(self) -> MediaFileKind:
        """The kind of file (MEDIA or DOWNLOADABLE) this metadata describes."""
        return self._kind

    @property
    def is_marked_for_delete(self) -> bool:
        """Whether the file has been flagged for deletion on a later cleanup."""
        return self._is_marked_for_delete

    def mark_for_delete(self) -> None:
        """Flag the file so that the next cleanup pass deletes it."""
        self._is_marked_for_delete = True


class MediaFileManager:
    """In-memory file manager for MediaFile objects.

    This keeps track of:
    - Which files exist, and what their IDs are. This is important so we can
      serve files by ID -- that's the whole point of this class!
    - Which files are being used by which AppSession (by ID). This is
      important so we can remove files from memory when no more sessions need
      them.
    - The exact location in the app where each file is being used (i.e. the
      file's "coordinates"). This is important so we can mark a file as "not
      being used by a certain session" if it gets replaced by another file at
      the same coordinates. For example, when doing an animation where the
      same image is constantly replaced with new frames. (This doesn't solve
      the case where the file's coordinates keep changing for some reason,
      though! e.g. if new elements keep being prepended to the app. Unlikely
      to happen, but we should address it at some point.)
    """

    def __init__(self, storage: MediaFileStorage) -> None:
        self._storage = storage

        # Dict of [file_id -> MediaFileMetadata]
        self._file_metadata: dict[str, MediaFileMetadata] = {}

        # Dict[session ID][coordinates] -> file_id.
        self._files_by_session_and_coord: dict[str, dict[str, str]] = (
            collections.defaultdict(dict)
        )

        # MediaFileManager is used from multiple threads, so all operations
        # need to be protected with a Lock. (This is not an RLock, which
        # means taking it multiple times from the same thread will deadlock.)
        self._lock = threading.Lock()

    def _get_inactive_file_ids(self) -> set[str]:
        """Compute the set of files that are stored in the manager, but are
        not referenced by any active session. These are files that can be
        safely deleted.

        Thread safety: callers must hold `self._lock`.

        Returns
        -------
        set[str]
            A newly built set of file IDs, so callers may mutate
            `self._file_metadata` while iterating it.
        """
        # Start from every file ID we know about...
        file_ids = set(self._file_metadata)

        # ...and subtract all IDs that are in use by each session.
        for session_file_ids_by_coord in self._files_by_session_and_coord.values():
            file_ids.difference_update(session_file_ids_by_coord.values())

        return file_ids

    def remove_orphaned_files(self) -> None:
        """Remove all files that are no longer referenced by any active session.

        MEDIA files are deleted as soon as they are found to be inactive.
        DOWNLOADABLE files get a one-pass grace period: the first call that
        finds one inactive only marks it for deletion, and a later call
        deletes it if it is still inactive and still marked. (Re-adding the
        file via `add` replaces its metadata and therefore clears the mark.)

        Safe to call from any thread.
        """
        _LOGGER.debug("Removing orphaned files...")

        with self._lock:
            for file_id in self._get_inactive_file_ids():
                metadata = self._file_metadata[file_id]

                if metadata.kind == MediaFileKind.MEDIA:
                    self._delete_file(file_id)
                elif metadata.kind == MediaFileKind.DOWNLOADABLE:
                    if metadata.is_marked_for_delete:
                        self._delete_file(file_id)
                    else:
                        # Presumably keeps the download URL valid for a while
                        # after the session stops referencing it.
                        metadata.mark_for_delete()

    def _delete_file(self, file_id: str) -> None:
        """Delete the given file from storage, and remove its metadata from
        self._file_metadata.

        Thread safety: callers must hold `self._lock`.
        """
        _LOGGER.debug("Deleting File: %s", file_id)
        self._storage.delete_file(file_id)
        del self._file_metadata[file_id]

    def clear_session_refs(self, session_id: str | None = None) -> None:
        """Remove the given session's file references.

        (This does not remove any files from the manager - you must call
        `remove_orphaned_files` for that.)

        Should be called whenever ScriptRunner completes and when a session
        ends.

        Safe to call from any thread.

        Parameters
        ----------
        session_id : str or None
            The session whose references to drop. Defaults to the session of
            the current script-run context.
        """
        if session_id is None:
            session_id = _get_session_id()

        _LOGGER.debug("Disconnecting files for session with ID %s", session_id)

        with self._lock:
            # pop() with a default is a no-op for unknown sessions and does
            # not trigger the defaultdict factory.
            self._files_by_session_and_coord.pop(session_id, None)

        _LOGGER.debug(
            "Sessions still active: %r", self._files_by_session_and_coord.keys()
        )

        _LOGGER.debug(
            "Files: %s; Sessions with files: %s",
            len(self._file_metadata),
            len(self._files_by_session_and_coord),
        )

    def add(
        self,
        path_or_data: bytes | str,
        mimetype: str,
        coordinates: str,
        file_name: str | None = None,
        is_for_static_download: bool = False,
    ) -> str:
        """Add a new MediaFile with the given parameters and return its URL.

        If an identical file already exists, return the existing URL and
        register the current session as a user.

        Safe to call from any thread.

        Parameters
        ----------
        path_or_data : bytes or str
            If bytes: the media file's raw data. If str: the name of a file
            to load from disk.
        mimetype : str
            The mime type for the file. E.g. "audio/mpeg".
            This string will be used in the "Content-Type" header when the file
            is served over HTTP.
        coordinates : str
            Unique string identifying an element's location.
            Prevents memory leak of "forgotten" file IDs when element media
            is being replaced-in-place (e.g. an st.image stream).
            coordinates should be of the form: "1.(3.-14).5"
        file_name : str or None
            Optional file_name. Used to set the filename in the response header.
        is_for_static_download: bool
            Indicates that the data is stored for downloading as a file,
            rather than as media to be rendered on the page.
            [default: False]

        Returns
        -------
        str
            The url that the frontend can use to fetch the media.

        Raises
        ------
        Exception
            Any exception raised by the storage while loading `path_or_data`
            (e.g. when a filename is passed and the file can't be read) is
            propagated unchanged; in that case nothing is registered.
        """
        session_id = _get_session_id()
        kind = (
            MediaFileKind.DOWNLOADABLE if is_for_static_download else MediaFileKind.MEDIA
        )

        with self._lock:
            file_id = self._storage.load_and_get_id(
                path_or_data, mimetype, kind, file_name
            )

            # NOTE(review): this replaces any existing metadata for file_id,
            # which clears a pending delete mark. If the storage can return
            # the same ID for a MEDIA and a DOWNLOADABLE add, the most recent
            # kind wins -- confirm that is intended.
            self._file_metadata[file_id] = MediaFileMetadata(kind=kind)

            # Registering at these coordinates drops this session's reference
            # to whatever file previously occupied them.
            self._files_by_session_and_coord[session_id][coordinates] = file_id

            return self._storage.get_url(file_id)