# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""@st.cache_resource implementation."""

from __future__ import annotations

import math
import threading
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Final,
    TypeVar,
    overload,
)

from cachetools import TTLCache
from typing_extensions import ParamSpec, TypeAlias

import streamlit as st
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.cache_errors import CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
    Cache,
    CachedFunc,
    CachedFuncInfo,
    make_cached_func_wrapper,
)
from streamlit.runtime.caching.cached_message_replay import (
    CachedMessageReplayContext,
    CachedResult,
    MsgData,
    show_widget_replay_deprecation,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
from streamlit.time_util import time_to_seconds

if TYPE_CHECKING:
    from datetime import timedelta

    from streamlit.runtime.caching.hashing import HashFuncsDict

_LOGGER: Final = get_logger(__name__)

CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)

ValidateFunc: TypeAlias = Callable[[Any], bool]


def _equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
    """True if the two validate functions are equal for the purposes of
    determining whether a given function cache needs to be recreated.
    """
    # To "properly" test for function equality here, we'd need to compare
    # function bytecode. For performance reasons, we've decided not to do
    # that for now.
    return (a is None and b is None) or (a is not None and b is not None)
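
# Note, derivable from the return expression above: only the *presence* of a
# validate function is compared, so two different callables are treated as
# "equal" and will not by themselves trigger cache recreation. For example
# (``check_a`` and ``check_b`` are hypothetical validate functions):
#
#   _equal_validate_funcs(None, None)        -> True
#   _equal_validate_funcs(check_a, check_b)  -> True  (both non-None)
#   _equal_validate_funcs(check_a, None)     -> False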

class ResourceCaches(CacheStatsProvider):
    """Manages all ResourceCache instances."""

    def __init__(self) -> None:
        self._caches_lock = threading.Lock()
        self._function_caches: dict[str, ResourceCache[Any]] = {}

    def get_cache(
        self,
        key: str,
        display_name: str,
        max_entries: int | float | None,
        ttl: float | timedelta | str | None,
        validate: ValidateFunc | None,
    ) -> ResourceCache[Any]:
        """Return the mem cache for the given key.

        If it doesn't exist, create a new one with the given params.
        """
        if max_entries is None:
            max_entries = math.inf
        ttl_seconds = time_to_seconds(ttl)

        # Get the existing cache, if it exists, and validate that its params
        # haven't changed.
        with self._caches_lock:
            cache = self._function_caches.get(key)
            if (
                cache is not None
                and cache.ttl_seconds == ttl_seconds
                and cache.max_entries == max_entries
                and _equal_validate_funcs(cache.validate, validate)
            ):
                return cache

            # Create a new cache object and put it in our dict.
            _LOGGER.debug("Creating new ResourceCache (key=%s)", key)
            cache = ResourceCache(
                key=key,
                display_name=display_name,
                max_entries=max_entries,
                ttl_seconds=ttl_seconds,
                validate=validate,
            )
            self._function_caches[key] = cache
            return cache

    def clear_all(self) -> None:
        """Clear all resource caches."""
        with self._caches_lock:
            self._function_caches = {}

    def get_stats(self) -> list[CacheStat]:
        with self._caches_lock:
            # Shallow-clone our caches. We don't want to hold the global
            # lock during stats-gathering.
            function_caches = self._function_caches.copy()

        stats: list[CacheStat] = []
        for cache in function_caches.values():
            stats.extend(cache.get_stats())
        return group_stats(stats)


# Singleton ResourceCaches instance.
_resource_caches = ResourceCaches()


def get_resource_cache_stats_provider() -> CacheStatsProvider:
    """Return the StatsProvider for all @st.cache_resource functions."""
    return _resource_caches


P = ParamSpec("P")
R = TypeVar("R")


class CachedResourceFuncInfo(CachedFuncInfo[P, R]):
    """Implements the CachedFuncInfo interface for @st.cache_resource."""

    def __init__(
        self,
        func: Callable[P, R],
        show_spinner: bool | str,
        max_entries: int | None,
        ttl: float | timedelta | str | None,
        validate: ValidateFunc | None,
        hash_funcs: HashFuncsDict | None = None,
        show_time: bool = False,
    ) -> None:
        super().__init__(
            func,
            hash_funcs=hash_funcs,
            show_spinner=show_spinner,
            show_time=show_time,
        )
        self.max_entries = max_entries
        self.ttl = ttl
        self.validate = validate

    @property
    def cache_type(self) -> CacheType:
        return CacheType.RESOURCE

    @property
    def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
        return CACHE_RESOURCE_MESSAGE_REPLAY_CTX

    @property
    def display_name(self) -> str:
        """A human-readable name for the cached function."""
        return f"{self.func.__module__}.{self.func.__qualname__}"

    def get_function_cache(self, function_key: str) -> Cache[R]:
        return _resource_caches.get_cache(
            key=function_key,
            display_name=self.display_name,
            max_entries=self.max_entries,
            ttl=self.ttl,
            validate=self.validate,
        )


class CacheResourceAPI:
    """Implements the public st.cache_resource API: the @st.cache_resource
    decorator, and st.cache_resource.clear().
    """

    def __init__(self, decorator_metric_name: str) -> None:
        """Create a CacheResourceAPI instance.

        Parameters
        ----------
        decorator_metric_name
            The metric name to record for decorator usage.
        """

        # Parameterize the decorator metric name.
        # (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
        self._decorator = gather_metrics(decorator_metric_name, self._decorator)  # type: ignore

    # Type-annotate the decorator function.
    # (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)

    # Bare decorator usage
    @overload
    def __call__(self, func: Callable[P, R]) -> CachedFunc[P, R]: ...

    # Decorator with arguments
    @overload
    def __call__(
        self,
        *,
        ttl: float | timedelta | str | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        show_time: bool = False,
        validate: ValidateFunc | None = None,
        experimental_allow_widgets: bool = False,
        hash_funcs: HashFuncsDict | None = None,
    ) -> Callable[[Callable[P, R]], CachedFunc[P, R]]: ...
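
    # Both overloads resolve to the runtime implementation below. As a usage
    # sketch (``load_model`` is a hypothetical user function, not part of this
    # module):
    #
    #   @st.cache_resource                            # bare decorator
    #   def load_model(): ...
    #
    #   @st.cache_resource(ttl="1h", max_entries=10)  # decorator with arguments
    #   def load_model(): ...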
    def __call__(
        self,
        func: Callable[P, R] | None = None,
        *,
        ttl: float | timedelta | str | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        show_time: bool = False,
        validate: ValidateFunc | None = None,
        experimental_allow_widgets: bool = False,
        hash_funcs: HashFuncsDict | None = None,
    ) -> CachedFunc[P, R] | Callable[[Callable[P, R]], CachedFunc[P, R]]:
        return self._decorator(
            func,
            ttl=ttl,
            max_entries=max_entries,
            show_spinner=show_spinner,
            show_time=show_time,
            validate=validate,
            experimental_allow_widgets=experimental_allow_widgets,
            hash_funcs=hash_funcs,
        )

    def _decorator(
        self,
        func: Callable[P, R] | None,
        *,
        ttl: float | timedelta | str | None,
        max_entries: int | None,
        show_spinner: bool | str,
        show_time: bool = False,
        validate: ValidateFunc | None,
        experimental_allow_widgets: bool,
        hash_funcs: HashFuncsDict | None = None,
    ) -> CachedFunc[P, R] | Callable[[Callable[P, R]], CachedFunc[P, R]]:
        """Decorator to cache functions that return global resources (e.g.
        database connections, ML models).

        Cached objects are shared across all users, sessions, and reruns. They
        must be thread-safe because they can be accessed from multiple threads
        concurrently. If thread safety is an issue, consider using
        ``st.session_state`` to store resources per session instead.

        You can clear a function's cache with ``func.clear()`` or clear the
        entire cache with ``st.cache_resource.clear()``.

        A function's arguments must be hashable to cache it. If you have an
        unhashable argument (like a database connection) or an argument you
        want to exclude from caching, use an underscore prefix in the argument
        name. In this case, Streamlit will return a cached value when all
        other arguments match a previous function call. Alternatively, you can
        declare custom hashing functions with ``hash_funcs``.

        To cache data, use ``st.cache_data`` instead. Learn more about caching
        at https://docs.streamlit.io/develop/concepts/architecture/caching.

        Parameters
        ----------
        func : callable
            The function that creates the cached resource. Streamlit hashes
            the function's source code.

        ttl : float, timedelta, str, or None
            The maximum time to keep an entry in the cache. Can be one of:

            - ``None`` if cache entries should never expire (default).
            - A number specifying the time in seconds.
            - A string specifying the time in a format supported by `Pandas's
              Timedelta constructor
              <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
              e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
            - A ``timedelta`` object from `Python's built-in datetime library
              <https://docs.python.org/3/library/datetime.html>`_, e.g.
              ``timedelta(days=1)``.

        max_entries : int or None
            The maximum number of entries to keep in the cache, or None for an
            unbounded cache. When a new entry is added to a full cache, the
            oldest cached entry will be removed. Defaults to None.

        show_spinner : bool or str
            Enable the spinner. The default is True, which shows a spinner
            when there is a "cache miss" and the cached resource is being
            created. If a string is passed, it is used as the spinner text.

        show_time : bool
            Whether to show the elapsed time next to the spinner text. If this
            is ``False`` (default), no time is displayed. If this is ``True``,
            elapsed time is displayed with a precision of 0.1 seconds. The
            time format is not configurable.

        validate : callable or None
            An optional validation function for cached data. ``validate`` is
            called each time the cached value is accessed. It receives the
            cached value as its only parameter and it must return a boolean.
            If ``validate`` returns False, the current cached value is
            discarded, and the decorated function is called to compute a new
            value. This is useful, for example, to check the health of
            database connections.

        experimental_allow_widgets : bool
            Allow widgets to be used in the cached function. Defaults to
            False.

        hash_funcs : dict or None
            Mapping of types or fully qualified names to hash functions.
            This is used to override the behavior of the hasher inside
            Streamlit's caching mechanism: when the hasher encounters an
            object, it will first check to see if its type matches a key in
            this dict and, if so, will use the provided function to generate a
            hash for it. See below for an example of how this can be used.

        .. deprecated::
            The cached widget replay functionality was removed in 1.38. Please
            remove the ``experimental_allow_widgets`` parameter from your
            caching decorators. This parameter will be removed in a future
            version.

        Example
        -------
        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(url):
        ...     # Create a database session object that points to the URL.
        ...     return session
        >>>
        >>> s1 = get_database_session(SESSION_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(SESSION_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value. This means that now the connection object in s1 is the same as in s2.
        >>>
        >>> s3 = get_database_session(SESSION_URL_2)
        >>> # This is a different URL, so the function executes.

        By default, all parameters to a cache_resource function must be
        hashable. Any parameter whose name begins with ``_`` will not be
        hashed. You can use this as an "escape hatch" for parameters that are
        not hashable:

        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        >>>
        >>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value - even though the _sessionmaker parameter was different
        >>> # in both calls.

        A cache_resource function's cache can be procedurally cleared:

        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        >>>
        >>> get_database_session.clear(create_sessionmaker(), "https://streamlit.io/")
        >>> # Clear the cached entry for the arguments provided.
        >>>
        >>> get_database_session.clear()
        >>> # Clear all cached entries for this function.

        To override the default hashing behavior, pass a custom hash function.
        You can do that by mapping a type (e.g. ``Person``) to a hash function
        (``str``) like this:

        >>> import streamlit as st
        >>> from pydantic import BaseModel
        >>>
        >>> class Person(BaseModel):
        ...     name: str
        >>>
        >>> @st.cache_resource(hash_funcs={Person: str})
        ... def get_person_name(person: Person):
        ...     return person.name

        Alternatively, you can map the type's fully-qualified name (e.g.
        ``"__main__.Person"``) to the hash function instead:

        >>> import streamlit as st
        >>> from pydantic import BaseModel
        >>>
        >>> class Person(BaseModel):
        ...     name: str
        >>>
        >>> @st.cache_resource(hash_funcs={"__main__.Person": str})
        ... def get_person_name(person: Person):
        ...     return person.name
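
        As a sketch of how ``validate`` can be used (``make_connection`` and
        the ``is_healthy`` method are hypothetical stand-ins for your own
        client's API, not part of Streamlit), a validator can discard
        unhealthy connections so they are recreated on the next access:

        >>> import streamlit as st
        >>>
        >>> def _connection_ok(conn) -> bool:
        ...     # Hypothetical health check; returning False discards the entry.
        ...     return conn.is_healthy()
        >>>
        >>> @st.cache_resource(validate=_connection_ok)
        ... def get_connection(url):
        ...     return make_connection(url)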
        """
        if experimental_allow_widgets:
            show_widget_replay_deprecation("cache_resource")

        # Support passing the params via function decorator, e.g.
        # @st.cache_resource(show_spinner=False)
        if func is None:
            return lambda f: make_cached_func_wrapper(
                CachedResourceFuncInfo(
                    func=f,
                    show_spinner=show_spinner,
                    show_time=show_time,
                    max_entries=max_entries,
                    ttl=ttl,
                    validate=validate,
                    hash_funcs=hash_funcs,
                )
            )

        return make_cached_func_wrapper(
            CachedResourceFuncInfo(
                func=func,
                show_spinner=show_spinner,
                show_time=show_time,
                max_entries=max_entries,
                ttl=ttl,
                validate=validate,
                hash_funcs=hash_funcs,
            )
        )

    @gather_metrics("clear_resource_caches")
    def clear(self) -> None:
        """Clear all cache_resource caches."""
        _resource_caches.clear_all()


class ResourceCache(Cache[R]):
    """Manages cached values for a single st.cache_resource function."""

    def __init__(
        self,
        key: str,
        max_entries: float,
        ttl_seconds: float,
        validate: ValidateFunc | None,
        display_name: str,
    ) -> None:
        super().__init__()
        self.key = key
        self.display_name = display_name
        self._mem_cache: TTLCache[str, CachedResult[R]] = TTLCache(
            maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
        )
        self._mem_cache_lock = threading.Lock()
        self.validate = validate

    @property
    def max_entries(self) -> float:
        return self._mem_cache.maxsize

    @property
    def ttl_seconds(self) -> float:
        return self._mem_cache.ttl

    def read_result(self, key: str) -> CachedResult[R]:
        """Read a value and associated messages from the cache.

        Raise `CacheKeyNotFoundError` if the value doesn't exist.
        """
        with self._mem_cache_lock:
            if key not in self._mem_cache:
                # The key does not exist in the cache.
                raise CacheKeyNotFoundError()

            result = self._mem_cache[key]

            if self.validate is not None and not self.validate(result.value):
                # Validation failed: delete the entry and raise an error.
                del self._mem_cache[key]
                raise CacheKeyNotFoundError()

            return result

    @gather_metrics("_cache_resource_object")
    def write_result(self, key: str, value: R, messages: list[MsgData]) -> None:
        """Write a value and associated messages to the cache."""
        main_id = st._main.id
        sidebar_id = st.sidebar.id

        with self._mem_cache_lock:
            self._mem_cache[key] = CachedResult(value, messages, main_id, sidebar_id)

    def _clear(self, key: str | None = None) -> None:
        with self._mem_cache_lock:
            if key is None:
                self._mem_cache.clear()
            elif key in self._mem_cache:
                del self._mem_cache[key]

    def get_stats(self) -> list[CacheStat]:
        # Shallow-clone our cache. Computing item sizes is potentially
        # expensive, and we want to minimize the time we spend holding
        # the lock.
        with self._mem_cache_lock:
            cache_entries = list(self._mem_cache.values())

        # Lazy-load the vendored package to prevent an import of numpy.
        from streamlit.vendor.pympler.asizeof import asizeof

        return [
            CacheStat(
                category_name="st_cache_resource",
                cache_name=self.display_name,
                byte_length=asizeof(entry),
            )
            for entry in cache_entries
        ]
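
# For reference, ResourceCache's expiry and eviction semantics come from
# cachetools.TTLCache, an LRU cache with a per-item time-to-live. A minimal
# standalone sketch of that behavior:
#
#   from cachetools import TTLCache
#
#   cache = TTLCache(maxsize=2, ttl=60)
#   cache["a"] = 1
#   cache["b"] = 2
#   cache["c"] = 3  # cache was full: the least-recently-used entry ("a") is evicted
#   # Each entry also expires 60 seconds after it was written; accessing an
#   # expired key raises KeyError.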