303 lines
11 KiB
Python
303 lines
11 KiB
Python
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import threading
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union
|
|
|
|
from typing_extensions import ParamSpec
|
|
|
|
import streamlit as st
|
|
from streamlit import runtime, util
|
|
from streamlit.deprecation_util import show_deprecation_warning
|
|
from streamlit.runtime.caching.cache_errors import CacheReplayClosureError
|
|
from streamlit.runtime.scriptrunner_utils.script_run_context import (
|
|
in_cached_function,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterator
|
|
|
|
from google.protobuf.message import Message
|
|
|
|
from streamlit.delta_generator import DeltaGenerator
|
|
from streamlit.proto.Block_pb2 import Block
|
|
from streamlit.runtime.caching.cache_type import CacheType
|
|
|
|
|
|
@dataclass(frozen=True)
class MediaMsgData:
    """Immutable record of a single media payload captured alongside an element.

    The three fields are handed, in this order, to the media file manager when
    the cached messages are replayed.
    """

    # The media content itself, either raw bytes or a string.
    media: bytes | str
    # MIME type passed along with the content on replay.
    mimetype: str
    # Identifier passed along with the content on replay.
    media_id: str
|
|
|
|
|
|
@dataclass(frozen=True)
class ElementMsgData:
    """Immutable record of one element message plus the metadata needed to
    replay the function call that produced it.

    `media_data` carries the media payloads (image, audio, video) that were
    recorded for the element; it defaults to None.
    """

    # Delta type forwarded to the DeltaGenerator when the message is replayed.
    delta_type: str
    # The element protobuf that was produced.
    message: Message
    # Id of the DG the replayed call is invoked on.
    id_of_dg_called_on: str
    # Id of the DG returned by the original call.
    returned_dgs_id: str
    # Media payloads to register again before the element is replayed.
    media_data: list[MediaMsgData] | None = None
|
|
|
|
|
|
@dataclass(frozen=True)
class BlockMsgData:
    """Immutable record of one block message plus the metadata needed to
    replay the call that created the block.
    """

    # The block protobuf that was produced.
    message: Block
    # Id of the DG the replayed call is invoked on.
    id_of_dg_called_on: str
    # Id of the DG returned by the original call.
    returned_dgs_id: str
|
|
|
|
|
|
# A recorded message is either an element message or a block message.
MsgData = Union[ElementMsgData, BlockMsgData]
|
|
|
|
|
|
# Type of the value returned by a cache-decorated function.
R = TypeVar("R")


@dataclass
class CachedResult(Generic[R]):
    """Everything stored for one call of a cache-decorated function: the
    returned value together with the recorded messages needed to replay the
    st functions that ran while it executed.
    """

    # The value the cached function returned.
    value: R
    # Recorded element/block messages, in the order they were produced.
    messages: list[MsgData]
    # Id recorded for the top-level main DG; mapped to `st._main` on replay.
    main_id: str
    # Id recorded for the top-level sidebar DG; mapped to `st.sidebar` on replay.
    sidebar_id: str
|
|
|
|
|
|
"""
|
|
Note [DeltaGenerator method invocation]
There are two top level DG instances defined for all apps:
`main`, which is for putting elements in the main part of the app
`sidebar`, for the sidebar

There are 3 different ways an st function can be invoked:
1. Implicitly on the main DG instance (plain `st.foo` calls)
2. Implicitly in an active contextmanager block (`st.foo` within a `with st.container` context)
3. Explicitly on a DG instance (`st.sidebar.foo`, `my_column_1.foo`)

To simplify replaying messages from a cached function result, we convert all of these
to explicit invocations. How they get rewritten depends on whether the invocation was
implicit vs explicit, and whether the target DG has been seen/produced during replay.

Implicit invocation on a known DG -> Explicit invocation on that DG
Implicit invocation on an unknown DG -> Rewrite as explicit invocation on main
    with st.container():
        my_cache_decorated_function()

    This is situation 2 above, and the DG is a block entirely outside our function call,
    so we interpret it as "put this element in the enclosing contextmanager block"
    (or main if there isn't one), which is achieved by invoking on main.
Explicit invocation on a known DG -> No change needed
Explicit invocation on an unknown DG -> Raise an error
    We have no way to identify the target DG, and it may not even be present in the
    current script run, so the least surprising thing to do is raise an error.
|
|
|
|
"""
|
|
|
|
|
|
class CachedMessageReplayContext(threading.local):
    """Recorder for the messages that `st` commands emit while a cached
    function is executing.

    The class derives from `threading.local`, so its state is kept per thread
    and one instance can safely be shared between threads.
    """

    def __init__(self, cache_type: CacheType) -> None:
        # One message list per cached function call currently in progress,
        # innermost call last.
        self._cached_message_stack: list[list[MsgData]] = []
        # Parallel stack holding, per call in progress, the ids of the DGs
        # returned by the commands recorded so far.
        self._seen_dg_stack: list[set[str]] = []
        # Messages recorded by the cached function call that finished last.
        self._most_recent_messages: list[MsgData] = []
        # Media payloads waiting to be attached to the next element message.
        self._media_data: list[MediaMsgData] = []
        self._cache_type = cache_type

    def __repr__(self) -> str:
        return util.repr_(self)

    @contextlib.contextmanager
    def calling_cached_function(self, func: Callable[..., Any]) -> Iterator[None]:  # noqa: ARG002
        """Wrap the invocation of a cached function.

        While the context is active, every `st.foo` message produced inside the
        function is recorded so it can be played back on a later cache hit.
        """
        self._cached_message_stack.append([])
        self._seen_dg_stack.append(set())

        # Remember whether an enclosing cached function had already raised the
        # flag; if so, it must stay raised once this call is done.
        entered_from_cached_function = bool(in_cached_function.get())

        # Mark the current run as being inside a cached function so that
        # widget-like elements are disallowed (their usage triggers a warning).
        in_cached_function.set(True)

        try:
            yield
        finally:
            self._most_recent_messages = self._cached_message_stack.pop()
            self._seen_dg_stack.pop()
            if not entered_from_cached_function:
                # Only the outermost cached function lowers the flag again.
                in_cached_function.set(False)

    def save_element_message(
        self,
        delta_type: str,
        element_proto: Message,
        invoked_dg_id: str,
        used_dg_id: str,
        returned_dg_id: str,
    ) -> None:
        """Record an element protobuf for every cached function currently
        executing, so it can be replayed whenever the function's execution is
        skipped because its result is in the cache.
        """
        if not (runtime.exists() and in_cached_function.get()):
            return

        if self._cached_message_stack:
            recorded = ElementMsgData(
                delta_type=delta_type,
                message=element_proto,
                id_of_dg_called_on=self.select_dg_to_save(invoked_dg_id, used_dg_id),
                returned_dgs_id=returned_dg_id,
                media_data=self._media_data,
            )
            for frame in self._cached_message_stack:
                frame.append(recorded)

        # The pending media payloads belong to the element just handled, so
        # start over with a fresh list for the next one.
        self._media_data = []

        for seen_ids in self._seen_dg_stack:
            seen_ids.add(returned_dg_id)

    def save_block_message(
        self,
        block_proto: Block,
        invoked_dg_id: str,
        used_dg_id: str,
        returned_dg_id: str,
    ) -> None:
        """Record a block protobuf for every cached function currently
        executing and mark the DG it returned as seen.
        """
        if not in_cached_function.get():
            return

        target_id = self.select_dg_to_save(invoked_dg_id, used_dg_id)
        for frame in self._cached_message_stack:
            frame.append(
                BlockMsgData(
                    message=block_proto,
                    id_of_dg_called_on=target_id,
                    returned_dgs_id=returned_dg_id,
                )
            )
        for seen_ids in self._seen_dg_stack:
            seen_ids.add(returned_dg_id)

    def select_dg_to_save(self, invoked_id: str, acting_on_id: str) -> str:
        """Pick the id of the DG a message should be invoked on during replay.

        See Note [DeltaGenerator method invocation]

        invoked_id is the DG the st function was called on, usually `st._main`.
        acting_on_id is the DG the st function ultimately runs on, which may be
        different if the invoked DG delegated to another one because it was in
        a `with` block.
        """
        # Only DGs seen by the innermost cached function call count as known.
        innermost_seen = self._seen_dg_stack[-1] if self._seen_dg_stack else ()
        return acting_on_id if acting_on_id in innermost_seen else invoked_id

    def save_media_data(
        self, media_data: bytes | str, mimetype: str, media_id: str
    ) -> None:
        """Queue a media payload to be attached to the next element message.

        Does nothing outside of a cached function.
        """
        if in_cached_function.get():
            self._media_data.append(MediaMsgData(media_data, mimetype, media_id))
|
|
|
|
|
|
# Parameter specification of the cached function handed to
# `replay_cached_messages`.
P = ParamSpec("P")
|
|
|
|
|
|
def replay_cached_messages(
    result: CachedResult[R], cache_type: CacheType, cached_func: Callable[P, R]
) -> None:
    """Re-issue the st element calls that were made while a cache-decorated
    function executed.

    While a cached function runs, the element and block messages it produces
    are recorded. On a cache hit the function body is skipped, so those
    recordings are used here to reproduce the DeltaGenerator calls and make the
    elements show up in the web app anyway.

    Each recorded message carries the id of the DG it was effectively called on
    (see Note [DeltaGenerator method invocation]) and the id of the DG the call
    returned. For every message we look up this script run's DG for the
    recorded target id, enqueue the message on it, and remember any DG that
    comes back in case a later recorded call targets it.

    Raises
    ------
    CacheReplayClosureError
        If a KeyError is raised while replaying, e.g. because a message
        targets a DG id that is not known at that point of the replay.
    """
    from streamlit.delta_generator import DeltaGenerator

    # Recorded dg id -> the equivalent dg of the current script run.
    live_dgs: dict[str, DeltaGenerator] = {}
    live_dgs[result.main_id] = st._main
    live_dgs[result.sidebar_id] = st.sidebar

    try:
        for recorded in result.messages:
            if isinstance(recorded, ElementMsgData):
                # Register the media files again before the element that
                # refers to them is enqueued.
                if recorded.media_data is not None:
                    for payload in recorded.media_data:
                        runtime.get_instance().media_file_mgr.add(
                            payload.media, payload.mimetype, payload.media_id
                        )
                target = live_dgs[recorded.id_of_dg_called_on]
                produced = target._enqueue(recorded.delta_type, recorded.message)
                if isinstance(produced, DeltaGenerator):
                    live_dgs[recorded.returned_dgs_id] = produced
            elif isinstance(recorded, BlockMsgData):
                target = live_dgs[recorded.id_of_dg_called_on]
                live_dgs[recorded.returned_dgs_id] = target._block(recorded.message)
    except KeyError as ex:
        raise CacheReplayClosureError(cache_type, cached_func) from ex
|
|
|
|
|
|
def show_widget_replay_deprecation(
    decorator: Literal["cache_data", "cache_resource"],
) -> None:
    """Show the deprecation warning for the `experimental_allow_widgets`
    parameter of the given cache decorator.

    Parameters
    ----------
    decorator
        Name of the cache decorator to mention in the warning text.
    """
    warning_text = (
        "The cached widget replay feature was removed in 1.38. The "
        "`experimental_allow_widgets` parameter will also be removed "
        "in a future release. Please remove the `experimental_allow_widgets` parameter "
        f"from the `@st.{decorator}` decorator and move all widget commands outside of "
        "cached functions.\n\nTo speed up your app, we recommend moving your widgets "
        "into fragments. Find out more about fragments in "
        "[our docs](https://docs.streamlit.io/develop/api-reference/execution-flow/st.fragment). "
        "\n\nIf you have a specific use-case that requires the "
        "`experimental_allow_widgets` functionality, please tell us via an "
        "[issue on Github](https://github.com/streamlit/streamlit/issues)."
    )
    show_deprecation_warning(warning_text)
|