team-10/env/Lib/site-packages/streamlit/runtime/caching/cached_message_replay.py
2025-08-02 07:34:44 +02:00

303 lines
11 KiB
Python

# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import threading
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union
from typing_extensions import ParamSpec
import streamlit as st
from streamlit import runtime, util
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.runtime.caching.cache_errors import CacheReplayClosureError
from streamlit.runtime.scriptrunner_utils.script_run_context import (
in_cached_function,
)
if TYPE_CHECKING:
from collections.abc import Iterator
from google.protobuf.message import Message
from streamlit.delta_generator import DeltaGenerator
from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_type import CacheType
@dataclass(frozen=True)
class MediaMsgData:
    """Immutable record of one media payload captured while a cached
    function was running.

    The three fields are handed, in this order, to the media file manager
    when the cached messages are replayed.
    """

    # Media content, either raw bytes or a string.
    media: bytes | str
    # Mimetype string stored alongside the content.
    mimetype: str
    # Identifier the media was recorded under.
    media_id: str
@dataclass(frozen=True)
class ElementMsgData:
    """Immutable record of a single element call made inside a cached
    function, carrying what is needed to replay that call later.

    ``media_data`` is filled in iff this is a media element (image, audio,
    video); it defaults to None.
    """

    # Delta type string of the element.
    delta_type: str
    # The element's protobuf message.
    message: Message
    # Id of the DG the call should be replayed on.
    id_of_dg_called_on: str
    # Id of the DG returned by the original call.
    returned_dgs_id: str
    # Media payloads associated with this element, if any.
    media_data: list[MediaMsgData] | None = None
@dataclass(frozen=True)
class BlockMsgData:
    """Immutable record of a single block creation made inside a cached
    function, carrying what is needed to replay it later.
    """

    # The block's protobuf message.
    message: Block
    # Id of the DG the block should be replayed on.
    id_of_dg_called_on: str
    # Id of the DG returned when the block was originally created.
    returned_dgs_id: str
# A recorded message is either an element call or a block creation.
MsgData = Union[ElementMsgData, BlockMsgData]
# Type variable for the return value of a cache-decorated function.
R = TypeVar("R")
@dataclass
class CachedResult(Generic[R]):
    """Everything stored for one call of a cache-decorated function: its
    return value plus the information needed to replay the st functions
    that were called while it executed.
    """

    # The value the cached function returned.
    value: R
    # Element and block messages recorded during the call.
    messages: list[MsgData]
    # Recorded id that replay maps to the current run's main DG.
    main_id: str
    # Recorded id that replay maps to the current run's sidebar DG.
    sidebar_id: str
"""
Note [DeltaGenerator method invocation]
There are two top level DG instances defined for all apps:
`main`, which is for putting elements in the main part of the app
`sidebar`, for the sidebar
There are 3 different ways an st function can be invoked:
1. Implicitly on the main DG instance (plain `st.foo` calls)
2. Implicitly in an active contextmanager block (`st.foo` within a `with st.container` context)
3. Explicitly on a DG instance (`st.sidebar.foo`, `my_column_1.foo`)
To simplify replaying messages from a cached function result, we convert all of these
to explicit invocations. How they get rewritten depends on if the invocation was
implicit vs explicit, and if the target DG has been seen/produced during replay.
Implicit invocation on a known DG -> Explicit invocation on that DG
Implicit invocation on an unknown DG -> Rewrite as explicit invocation on main
    For example:
        with st.container():
            my_cache_decorated_function()
    This is situation 2 above, and the DG is a block entirely outside our function call,
    so we interpret it as "put this element in the enclosing contextmanager block"
    (or main if there isn't one), which is achieved by invoking on main.
Explicit invocation on a known DG -> No change needed
Explicit invocation on an unknown DG -> Raise an error
We have no way to identify the target DG, and it may not even be present in the
current script run, so the least surprising thing to do is raise an error.
"""
class CachedMessageReplayContext(threading.local):
    """Recorder for the messages that `st` commands generate while a cached
    function is executing.

    All state lives on a thread-local object, so a single instance of this
    class can safely be used from multiple threads.
    """

    def __init__(self, cache_type: CacheType) -> None:
        # One message list per cached function currently executing; a nested
        # cached call pushes another list on top.
        self._cached_message_stack: list[list[MsgData]] = []
        # Pushed and popped together with the message stack: the ids of the
        # DGs returned while each cached function has been running.
        self._seen_dg_stack: list[set[str]] = []
        # The message list popped when the latest cached call finished.
        self._most_recent_messages: list[MsgData] = []
        # Media payloads waiting to be attached to the next saved element.
        self._media_data: list[MediaMsgData] = []
        self._cache_type = cache_type

    def __repr__(self) -> str:
        return util.repr_(self)

    @contextlib.contextmanager
    def calling_cached_function(self, func: Callable[..., Any]) -> Iterator[None]:  # noqa: ARG002
        """Context manager to wrap around the invocation of a cached function.

        While the context is active, `st.foo` messages generated inside the
        function are tracked so they can be played back on cache retrieval.
        """
        self._cached_message_stack.append([])
        self._seen_dg_stack.append(set())

        # Remember whether an enclosing cached function had already raised
        # the flag, so that only the outermost call lowers it again.
        already_in_cached_function = bool(in_cached_function.get())

        # Mark this run as being inside a cached function. The flag is what
        # disallows widget-like elements here (their usage triggers a warning).
        in_cached_function.set(True)
        try:
            yield
        finally:
            self._most_recent_messages = self._cached_message_stack.pop()
            self._seen_dg_stack.pop()
            if not already_in_cached_function:
                # Reset the in_cached_function flag, but only when this call
                # is not nested inside another cached function that
                # disallows widget usage.
                in_cached_function.set(False)

    def save_element_message(
        self,
        delta_type: str,
        element_proto: Message,
        invoked_dg_id: str,
        used_dg_id: str,
        returned_dg_id: str,
    ) -> None:
        """Attach the element protobuf to every cached function that is
        currently executing, so it can be replayed whenever the function's
        execution is skipped because its result is in the cache.

        Does nothing unless the runtime exists and a cached function is
        running.
        """
        if not (runtime.exists() and in_cached_function.get()):
            return

        if self._cached_message_stack:
            recorded = ElementMsgData(
                delta_type=delta_type,
                message=element_proto,
                id_of_dg_called_on=self.select_dg_to_save(invoked_dg_id, used_dg_id),
                returned_dgs_id=returned_dg_id,
                media_data=self._media_data,
            )
            # Every enclosing cached function records the same message.
            for pending in self._cached_message_stack:
                pending.append(recorded)

        # The pending media now belongs to the element above; start a fresh
        # list for the next one.
        self._media_data = []

        for seen_ids in self._seen_dg_stack:
            seen_ids.add(returned_dg_id)

    def save_block_message(
        self,
        block_proto: Block,
        invoked_dg_id: str,
        used_dg_id: str,
        returned_dg_id: str,
    ) -> None:
        """Attach the block protobuf to every cached function that is
        currently executing and mark the returned DG id as seen.

        Does nothing when no cached function is running.
        """
        if not in_cached_function.get():
            return

        replay_target_id = self.select_dg_to_save(invoked_dg_id, used_dg_id)
        for pending in self._cached_message_stack:
            pending.append(
                BlockMsgData(
                    message=block_proto,
                    id_of_dg_called_on=replay_target_id,
                    returned_dgs_id=returned_dg_id,
                )
            )
        for seen_ids in self._seen_dg_stack:
            seen_ids.add(returned_dg_id)

    def select_dg_to_save(self, invoked_id: str, acting_on_id: str) -> str:
        """Pick the id of the DG a message should be invoked on during
        message replay.

        See Note [DeltaGenerator method invocation]

        invoked_id is the DG the st function was called on, usually `st._main`.
        acting_on_id is the DG the st function ultimately runs on, which may
        differ when the invoked DG delegated to another one because it was
        in a `with` block.

        Returns acting_on_id when the innermost cached function has already
        seen that DG, and invoked_id otherwise.
        """
        if not self._seen_dg_stack:
            return invoked_id
        innermost_seen = self._seen_dg_stack[-1]
        return acting_on_id if acting_on_id in innermost_seen else invoked_id

    def save_media_data(
        self, media_data: bytes | str, mimetype: str, media_id: str
    ) -> None:
        """Queue a media payload for the next element that gets saved.

        Does nothing when no cached function is running.
        """
        if in_cached_function.get():
            self._media_data.append(MediaMsgData(media_data, mimetype, media_id))
# Parameter specification used when typing the cached function's signature.
P = ParamSpec("P")
def replay_cached_messages(
    result: CachedResult[R], cache_type: CacheType, cached_func: Callable[P, R]
) -> None:
    """Replay the st element function calls that happened when executing a
    cache-decorated function.

    When a cache function is executed, we record the element and block messages
    produced, and use those to reproduce the DeltaGenerator calls, so the elements
    will appear in the web app even when execution of the function is skipped
    because the result was cached.

    To make this work, for each st function call we record an identifier for the
    DG it was effectively called on (see Note [DeltaGenerator method invocation]).
    We also record the identifier for each DG returned by an st function call, if
    it returns one. Then, for each recorded message, we get the current DG instance
    corresponding to the DG the message was originally called on, and enqueue the
    message using that, recording any new DGs produced in case a later st function
    call is on one of them.

    Parameters
    ----------
    result
        The cached result whose recorded messages should be replayed.
    cache_type
        Passed to ``CacheReplayClosureError`` when replay fails.
    cached_func
        The cache-decorated function; passed to ``CacheReplayClosureError``
        when replay fails.

    Raises
    ------
    CacheReplayClosureError
        If a recorded message targets a DG id that is neither the recorded
        main/sidebar id nor one produced earlier in this replay.
    """
    from streamlit.delta_generator import DeltaGenerator

    # Maps originally recorded dg ids to this script run's version of that dg
    returned_dgs: dict[str, DeltaGenerator] = {
        result.main_id: st._main,
        result.sidebar_id: st.sidebar,
    }

    def _resolve_dg(recorded_id: str) -> DeltaGenerator:
        # Only a failed lookup here means the target DG is unknown to this
        # replay. Keeping the try this narrow prevents a KeyError raised
        # inside the enqueue/media calls below from being misreported as a
        # closure error.
        try:
            return returned_dgs[recorded_id]
        except KeyError as ex:
            raise CacheReplayClosureError(cache_type, cached_func) from ex

    for msg in result.messages:
        if isinstance(msg, ElementMsgData):
            if msg.media_data is not None:
                # Register the recorded media before enqueueing the element.
                for data in msg.media_data:
                    runtime.get_instance().media_file_mgr.add(
                        data.media, data.mimetype, data.media_id
                    )
            dg = _resolve_dg(msg.id_of_dg_called_on)
            maybe_dg = dg._enqueue(msg.delta_type, msg.message)
            if isinstance(maybe_dg, DeltaGenerator):
                # Later messages may have been called on the DG this element
                # returned.
                returned_dgs[msg.returned_dgs_id] = maybe_dg
        elif isinstance(msg, BlockMsgData):
            dg = _resolve_dg(msg.id_of_dg_called_on)
            new_dg = dg._block(msg.message)
            returned_dgs[msg.returned_dgs_id] = new_dg
def show_widget_replay_deprecation(
    decorator: Literal["cache_data", "cache_resource"],
) -> None:
    """Show the deprecation warning for the removed cached widget replay
    feature and its `experimental_allow_widgets` parameter.

    ``decorator`` is the name of the caching decorator in use; it is
    interpolated into the message as `@st.<decorator>`.
    """
    warning_text = (
        "The cached widget replay feature was removed in 1.38. The "
        "`experimental_allow_widgets` parameter will also be removed "
        "in a future release. Please remove the `experimental_allow_widgets` parameter "
        f"from the `@st.{decorator}` decorator and move all widget commands outside of "
        "cached functions.\n\nTo speed up your app, we recommend moving your widgets "
        "into fragments. Find out more about fragments in "
        "[our docs](https://docs.streamlit.io/develop/api-reference/execution-flow/st.fragment). "
        "\n\nIf you have a specific use-case that requires the "
        "`experimental_allow_widgets` functionality, please tell us via an "
        "[issue on Github](https://github.com/streamlit/streamlit/issues)."
    )
    show_deprecation_warning(warning_text)