team-10/env/Lib/site-packages/streamlit/runtime/caching/cache_resource_api.py
2025-08-02 07:34:44 +02:00

550 lines
20 KiB
Python

# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.cache_resource implementation."""
from __future__ import annotations
import math
import threading
from typing import (
TYPE_CHECKING,
Any,
Callable,
Final,
TypeVar,
overload,
)
from cachetools import TTLCache
from typing_extensions import ParamSpec, TypeAlias
import streamlit as st
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.cache_errors import CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFunc,
CachedFuncInfo,
make_cached_func_wrapper,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
MsgData,
show_widget_replay_deprecation,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
from streamlit.time_util import time_to_seconds
if TYPE_CHECKING:
from datetime import timedelta
from streamlit.runtime.caching.hashing import HashFuncsDict
_LOGGER: Final = get_logger(__name__)

# Replay context used to record and later re-send the st.* messages that were
# emitted while a cached resource function originally ran.
CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)

# Signature of the user-supplied `validate` callback: receives the cached
# value as its only argument and returns True if it is still usable.
ValidateFunc: TypeAlias = Callable[[Any], bool]
def _equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
"""True if the two validate functions are equal for the purposes of
determining whether a given function cache needs to be recreated.
"""
# To "properly" test for function equality here, we'd need to compare function bytecode.
# For performance reasons, We've decided not to do that for now.
return (a is None and b is None) or (a is not None and b is not None)
class ResourceCaches(CacheStatsProvider):
    """Manages all ResourceCache instances."""

    def __init__(self) -> None:
        self._caches_lock = threading.Lock()
        self._function_caches: dict[str, ResourceCache[Any]] = {}

    def get_cache(
        self,
        key: str,
        display_name: str,
        max_entries: int | float | None,
        ttl: float | timedelta | str | None,
        validate: ValidateFunc | None,
    ) -> ResourceCache[Any]:
        """Return the mem cache for the given key.

        If it doesn't exist, create a new one with the given params.
        """
        effective_max_entries = math.inf if max_entries is None else max_entries
        ttl_seconds = time_to_seconds(ttl)

        with self._caches_lock:
            # Reuse the existing cache only if all of its params still match.
            existing = self._function_caches.get(key)
            if (
                existing is not None
                and existing.ttl_seconds == ttl_seconds
                and existing.max_entries == effective_max_entries
                and _equal_validate_funcs(existing.validate, validate)
            ):
                return existing

            # No cache yet, or its params changed: build a fresh one.
            _LOGGER.debug("Creating new ResourceCache (key=%s)", key)
            fresh = ResourceCache(
                key=key,
                display_name=display_name,
                max_entries=effective_max_entries,
                ttl_seconds=ttl_seconds,
                validate=validate,
            )
            self._function_caches[key] = fresh
            return fresh

    def clear_all(self) -> None:
        """Clear all resource caches."""
        with self._caches_lock:
            self._function_caches = {}

    def get_stats(self) -> list[CacheStat]:
        # Snapshot the cache dict under the lock, then gather stats without
        # holding the global lock (stats-gathering can be slow).
        with self._caches_lock:
            snapshot = list(self._function_caches.values())

        all_stats = [stat for cache in snapshot for stat in cache.get_stats()]
        return group_stats(all_stats)
# Singleton ResourceCaches instance shared by every @st.cache_resource
# function in this process.
_resource_caches = ResourceCaches()


def get_resource_cache_stats_provider() -> CacheStatsProvider:
    """Return the StatsProvider for all @st.cache_resource functions."""
    return _resource_caches
# Type variables describing a cached function's parameters and return value.
P = ParamSpec("P")
R = TypeVar("R")
class CachedResourceFuncInfo(CachedFuncInfo[P, R]):
    """Implements the CachedFuncInfo interface for @st.cache_resource."""

    def __init__(
        self,
        func: Callable[P, R],
        show_spinner: bool | str,
        max_entries: int | None,
        ttl: float | timedelta | str | None,
        validate: ValidateFunc | None,
        hash_funcs: HashFuncsDict | None = None,
        show_time: bool = False,
    ) -> None:
        super().__init__(
            func,
            hash_funcs=hash_funcs,
            show_spinner=show_spinner,
            show_time=show_time,
        )
        # Cache-configuration params specific to @st.cache_resource.
        self.ttl = ttl
        self.max_entries = max_entries
        self.validate = validate

    @property
    def cache_type(self) -> CacheType:
        return CacheType.RESOURCE

    @property
    def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
        return CACHE_RESOURCE_MESSAGE_REPLAY_CTX

    @property
    def display_name(self) -> str:
        """A human-readable name for the cached function."""
        return ".".join((self.func.__module__, self.func.__qualname__))

    def get_function_cache(self, function_key: str) -> Cache[R]:
        # Delegate to the process-wide singleton, which creates (or reuses)
        # the per-function ResourceCache for this function key.
        return _resource_caches.get_cache(
            key=function_key,
            display_name=self.display_name,
            max_entries=self.max_entries,
            ttl=self.ttl,
            validate=self.validate,
        )
class CacheResourceAPI:
    """Implements the public st.cache_resource API: the @st.cache_resource decorator,
    and st.cache_resource.clear().
    """

    def __init__(self, decorator_metric_name: str) -> None:
        """Create a CacheResourceAPI instance.

        Parameters
        ----------
        decorator_metric_name
            The metric name to record for decorator usage.
        """

        # Parameterize the decorator metric name.
        # (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
        self._decorator = gather_metrics(decorator_metric_name, self._decorator)  # type: ignore

    # Type-annotate the decorator function.
    # (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)

    # Bare decorator usage
    @overload
    def __call__(self, func: Callable[P, R]) -> CachedFunc[P, R]: ...

    # Decorator with arguments
    @overload
    def __call__(
        self,
        *,
        ttl: float | timedelta | str | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        show_time: bool = False,
        validate: ValidateFunc | None = None,
        experimental_allow_widgets: bool = False,
        hash_funcs: HashFuncsDict | None = None,
    ) -> Callable[[Callable[P, R]], CachedFunc[P, R]]: ...

    def __call__(
        self,
        func: Callable[P, R] | None = None,
        *,
        ttl: float | timedelta | str | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        show_time: bool = False,
        validate: ValidateFunc | None = None,
        experimental_allow_widgets: bool = False,
        hash_funcs: HashFuncsDict | None = None,
    ) -> CachedFunc[P, R] | Callable[[Callable[P, R]], CachedFunc[P, R]]:
        # Thin pass-through so that metrics gathering (attached to
        # self._decorator in __init__) wraps every decorator invocation.
        return self._decorator(
            func,
            ttl=ttl,
            max_entries=max_entries,
            show_spinner=show_spinner,
            show_time=show_time,
            validate=validate,
            experimental_allow_widgets=experimental_allow_widgets,
            hash_funcs=hash_funcs,
        )

    def _decorator(
        self,
        func: Callable[P, R] | None,
        *,
        ttl: float | timedelta | str | None,
        max_entries: int | None,
        show_spinner: bool | str,
        show_time: bool = False,
        validate: ValidateFunc | None,
        experimental_allow_widgets: bool,
        hash_funcs: HashFuncsDict | None = None,
    ) -> CachedFunc[P, R] | Callable[[Callable[P, R]], CachedFunc[P, R]]:
        """Decorator to cache functions that return global resources (e.g. database connections, ML models).

        Cached objects are shared across all users, sessions, and reruns. They
        must be thread-safe because they can be accessed from multiple threads
        concurrently. If thread safety is an issue, consider using ``st.session_state``
        to store resources per session instead.

        You can clear a function's cache with ``func.clear()`` or clear the entire
        cache with ``st.cache_resource.clear()``.

        A function's arguments must be hashable to cache it. If you have an
        unhashable argument (like a database connection) or an argument you
        want to exclude from caching, use an underscore prefix in the argument
        name. In this case, Streamlit will return a cached value when all other
        arguments match a previous function call. Alternatively, you can
        declare custom hashing functions with ``hash_funcs``.

        To cache data, use ``st.cache_data`` instead. Learn more about caching at
        https://docs.streamlit.io/develop/concepts/architecture/caching.

        Parameters
        ----------
        func : callable
            The function that creates the cached resource. Streamlit hashes the
            function's source code.

        ttl : float, timedelta, str, or None
            The maximum time to keep an entry in the cache. Can be one of:

            - ``None`` if cache entries should never expire (default).
            - A number specifying the time in seconds.
            - A string specifying the time in a format supported by `Pandas's
              Timedelta constructor <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
              e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
            - A ``timedelta`` object from `Python's built-in datetime library
              <https://docs.python.org/3/library/datetime.html#timedelta-objects>`_,
              e.g. ``timedelta(days=1)``.

        max_entries : int or None
            The maximum number of entries to keep in the cache, or None
            for an unbounded cache. When a new entry is added to a full cache,
            the oldest cached entry will be removed. Defaults to None.

        show_spinner : bool or str
            Enable the spinner. Default is True to show a spinner when there is
            a "cache miss" and the cached resource is being created. If string,
            value of show_spinner param will be used for spinner text.

        show_time : bool
            Whether to show the elapsed time next to the spinner text. If this is
            ``False`` (default), no time is displayed. If this is ``True``,
            elapsed time is displayed with a precision of 0.1 seconds. The time
            format is not configurable.

        validate : callable or None
            An optional validation function for cached data. ``validate`` is called
            each time the cached value is accessed. It receives the cached value as
            its only parameter and it must return a boolean. If ``validate`` returns
            False, the current cached value is discarded, and the decorated function
            is called to compute a new value. This is useful e.g. to check the
            health of database connections.

        experimental_allow_widgets : bool
            Allow widgets to be used in the cached function. Defaults to False.

            .. deprecated::
                The cached widget replay functionality was removed in 1.38. Please
                remove the ``experimental_allow_widgets`` parameter from your
                caching decorators. This parameter will be removed in a future
                version.

        hash_funcs : dict or None
            Mapping of types or fully qualified names to hash functions.
            This is used to override the behavior of the hasher inside Streamlit's
            caching mechanism: when the hasher encounters an object, it will first
            check to see if its type matches a key in this dict and, if so, will use
            the provided function to generate a hash for it. See below for an example
            of how this can be used.

        Example
        -------
        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(url):
        ...     # Create a database session object that points to the URL.
        ...     return session
        >>>
        >>> s1 = get_database_session(SESSION_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(SESSION_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value. This means that now the connection object in s1 is the same as in s2.
        >>>
        >>> s3 = get_database_session(SESSION_URL_2)
        >>> # This is a different URL, so the function executes.

        By default, all parameters to a cache_resource function must be hashable.
        Any parameter whose name begins with ``_`` will not be hashed. You can use
        this as an "escape hatch" for parameters that are not hashable:

        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        >>>
        >>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value - even though the _sessionmaker parameter was different
        >>> # in both calls.

        A cache_resource function's cache can be procedurally cleared:

        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        >>>
        >>> get_database_session.clear(_sessionmaker, "https://streamlit.io/")
        >>> # Clear the cached entry for the arguments provided.
        >>>
        >>> get_database_session.clear()
        >>> # Clear all cached entries for this function.

        To override the default hashing behavior, pass a custom hash function.
        You can do that by mapping a type (e.g. ``Person``) to a hash
        function (``str``) like this:

        >>> import streamlit as st
        >>> from pydantic import BaseModel
        >>>
        >>> class Person(BaseModel):
        ...     name: str
        >>>
        >>> @st.cache_resource(hash_funcs={Person: str})
        ... def get_person_name(person: Person):
        ...     return person.name

        Alternatively, you can map the type's fully-qualified name
        (e.g. ``"__main__.Person"``) to the hash function instead:

        >>> import streamlit as st
        >>> from pydantic import BaseModel
        >>>
        >>> class Person(BaseModel):
        ...     name: str
        >>>
        >>> @st.cache_resource(hash_funcs={"__main__.Person": str})
        ... def get_person_name(person: Person):
        ...     return person.name
        """
        # The widget-replay feature was removed in 1.38; warn users who still
        # pass the deprecated flag.
        if experimental_allow_widgets:
            show_widget_replay_deprecation("cache_resource")

        # Support passing the params via function decorator, e.g.
        # @st.cache_resource(show_spinner=False)
        if func is None:
            return lambda f: make_cached_func_wrapper(
                CachedResourceFuncInfo(
                    func=f,
                    show_spinner=show_spinner,
                    show_time=show_time,
                    max_entries=max_entries,
                    ttl=ttl,
                    validate=validate,
                    hash_funcs=hash_funcs,
                )
            )

        # Bare decorator usage: @st.cache_resource with no arguments.
        return make_cached_func_wrapper(
            CachedResourceFuncInfo(
                func=func,
                show_spinner=show_spinner,
                show_time=show_time,
                max_entries=max_entries,
                ttl=ttl,
                validate=validate,
                hash_funcs=hash_funcs,
            )
        )

    @gather_metrics("clear_resource_caches")
    def clear(self) -> None:
        """Clear all cache_resource caches."""
        _resource_caches.clear_all()
class ResourceCache(Cache[R]):
    """Manages cached values for a single st.cache_resource function."""

    def __init__(
        self,
        key: str,
        max_entries: float,
        ttl_seconds: float,
        validate: ValidateFunc | None,
        display_name: str,
    ) -> None:
        super().__init__()
        self.key = key
        self.display_name = display_name
        # TTLCache handles both entry-count bounding (maxsize; math.inf means
        # unbounded) and timer-driven expiry (ttl).
        self._mem_cache: TTLCache[str, CachedResult[R]] = TTLCache(
            maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
        )
        self._mem_cache_lock = threading.Lock()
        self.validate = validate

    @property
    def max_entries(self) -> float:
        """Maximum number of entries the cache may hold (math.inf = unbounded)."""
        return self._mem_cache.maxsize

    @property
    def ttl_seconds(self) -> float:
        """Time-to-live of cache entries, in seconds."""
        return self._mem_cache.ttl

    def read_result(self, key: str) -> CachedResult[R]:
        """Read a value and associated messages from the cache.

        Raise `CacheKeyNotFoundError` if the value doesn't exist, has
        expired, or fails the user-supplied `validate` check.
        """
        with self._mem_cache_lock:
            # Bug fix: look the key up directly (EAFP) instead of checking
            # membership first. TTLCache expiry is timer-based, so an entry
            # could expire between a `key in cache` test and the subsequent
            # `cache[key]` lookup — even while the lock is held — which
            # would leak a raw KeyError instead of CacheKeyNotFoundError.
            try:
                result = self._mem_cache[key]
            except KeyError:
                # Key does not exist (or just expired) in the cache.
                raise CacheKeyNotFoundError() from None

            if self.validate is not None and not self.validate(result.value):
                # Validate failed: delete the entry and report a cache miss
                # so the decorated function is re-run for a fresh value.
                del self._mem_cache[key]
                raise CacheKeyNotFoundError()

            return result

    @gather_metrics("_cache_resource_object")
    def write_result(self, key: str, value: R, messages: list[MsgData]) -> None:
        """Write a value and associated messages to the cache."""
        # Capture the current main/sidebar container ids outside the lock so
        # replayed messages can be routed to the right containers later.
        main_id = st._main.id
        sidebar_id = st.sidebar.id
        with self._mem_cache_lock:
            self._mem_cache[key] = CachedResult(value, messages, main_id, sidebar_id)

    def _clear(self, key: str | None = None) -> None:
        """Remove one entry (if `key` is given) or every entry from the cache."""
        with self._mem_cache_lock:
            if key is None:
                self._mem_cache.clear()
            elif key in self._mem_cache:
                del self._mem_cache[key]

    def get_stats(self) -> list[CacheStat]:
        """Return a CacheStat (with approximate byte size) per cache entry."""
        # Shallow clone our cache. Computing item sizes is potentially
        # expensive, and we want to minimize the time we spend holding
        # the lock.
        with self._mem_cache_lock:
            cache_entries = list(self._mem_cache.values())

        # Lazy-load vendored package to prevent import of numpy
        from streamlit.vendor.pympler.asizeof import asizeof

        return [
            CacheStat(
                category_name="st_cache_resource",
                cache_name=self.display_name,
                byte_length=asizeof(entry),
            )
            for entry in cache_entries
        ]