311 lines
12 KiB
Python
311 lines
12 KiB
Python
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
from dataclasses import dataclass, field, replace
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, cast
|
|
|
|
from streamlit import util
|
|
from streamlit.proto.Common_pb2 import ChatInputValue as ChatInputValueProto
|
|
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
|
|
|
|
if TYPE_CHECKING:
|
|
from streamlit.proto.ClientState_pb2 import ContextInfo
|
|
|
|
|
|
class ScriptRequestType(Enum):
|
|
# The ScriptRunner should continue running its script.
|
|
CONTINUE = "CONTINUE"
|
|
|
|
# If the script is running, it should be stopped as soon
|
|
# as the ScriptRunner reaches an interrupt point.
|
|
# This is a terminal state.
|
|
STOP = "STOP"
|
|
|
|
# A script rerun has been requested. The ScriptRunner should
|
|
# handle this request as soon as it reaches an interrupt point.
|
|
RERUN = "RERUN"
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class RerunData:
|
|
"""Data attached to RERUN requests. Immutable."""
|
|
|
|
query_string: str = ""
|
|
widget_states: WidgetStates | None = None
|
|
page_script_hash: str = ""
|
|
page_name: str = ""
|
|
|
|
# A single fragment_id to append to fragment_id_queue.
|
|
fragment_id: str | None = None
|
|
# The queue of fragment_ids waiting to be run.
|
|
fragment_id_queue: list[str] = field(default_factory=list)
|
|
is_fragment_scoped_rerun: bool = False
|
|
# set to true when a script is rerun by the fragment auto-rerun mechanism
|
|
is_auto_rerun: bool = False
|
|
# Hashes of messages that are cached in the client browser:
|
|
cached_message_hashes: set[str] = field(default_factory=set)
|
|
# context_info is used to store information from the user browser (e.g. timezone)
|
|
context_info: ContextInfo | None = None
|
|
|
|
def __repr__(self) -> str:
|
|
return util.repr_(self)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ScriptRequest:
|
|
"""A STOP or RERUN request and associated data."""
|
|
|
|
type: ScriptRequestType
|
|
_rerun_data: RerunData | None = None
|
|
|
|
@property
|
|
def rerun_data(self) -> RerunData:
|
|
if self.type is not ScriptRequestType.RERUN:
|
|
raise RuntimeError("RerunData is only set for RERUN requests.")
|
|
return cast("RerunData", self._rerun_data)
|
|
|
|
def __repr__(self) -> str:
|
|
return util.repr_(self)
|
|
|
|
|
|
def _fragment_run_should_not_preempt_script(
|
|
fragment_id_queue: list[str],
|
|
is_fragment_scoped_rerun: bool,
|
|
) -> bool:
|
|
"""Returns whether the currently running script should be preempted due to a
|
|
fragment rerun.
|
|
|
|
Reruns corresponding to fragment runs that weren't caused by calls to
|
|
`st.rerun(scope="fragment")` should *not* cancel the current script run
|
|
as doing so will affect elements outside of the fragment.
|
|
"""
|
|
return bool(fragment_id_queue) and not is_fragment_scoped_rerun
|
|
|
|
|
|
def _coalesce_widget_states(
|
|
old_states: WidgetStates | None, new_states: WidgetStates | None
|
|
) -> WidgetStates | None:
|
|
"""Coalesce an older WidgetStates into a newer one, and return a new
|
|
WidgetStates containing the result.
|
|
|
|
For most widget values, we just take the latest version.
|
|
|
|
However, any trigger_values (which are set by buttons) that are True in
|
|
`old_states` will be set to True in the coalesced result, so that button
|
|
presses don't go missing.
|
|
"""
|
|
if not old_states and not new_states:
|
|
return None
|
|
if not old_states:
|
|
return new_states
|
|
if not new_states:
|
|
return old_states
|
|
|
|
states_by_id: dict[str, WidgetState] = {
|
|
wstate.id: wstate for wstate in new_states.widgets
|
|
}
|
|
|
|
trigger_value_types = [
|
|
("trigger_value", False),
|
|
("chat_input_value", ChatInputValueProto(data=None)),
|
|
]
|
|
for old_state in old_states.widgets:
|
|
for trigger_value_type, unset_value in trigger_value_types:
|
|
if (
|
|
old_state.WhichOneof("value") == trigger_value_type
|
|
and getattr(old_state, trigger_value_type) != unset_value
|
|
):
|
|
new_trigger_val = states_by_id.get(old_state.id)
|
|
# It should nearly always be the case that new_trigger_val is None
|
|
# here as trigger values are deleted from the client's WidgetStateManager
|
|
# as soon as a rerun_script BackMsg is sent to the server. Since it's
|
|
# impossible to test that the client sends us state in the expected
|
|
# format in a unit test, we test for this behavior in
|
|
# e2e_playwright/test_fragment_queue_test.py
|
|
if not new_trigger_val or (
|
|
# Ensure the corresponding new_state is also a trigger;
|
|
# otherwise, a widget that was previously a button/chat_input but no
|
|
# longer is could get a bad value.
|
|
new_trigger_val.WhichOneof("value") == trigger_value_type
|
|
# We only want to take the value of old_state if new_trigger_val is
|
|
# unset as the old value may be stale if a newer one was entered.
|
|
and getattr(new_trigger_val, trigger_value_type) == unset_value
|
|
):
|
|
states_by_id[old_state.id] = old_state
|
|
|
|
coalesced = WidgetStates()
|
|
coalesced.widgets.extend(states_by_id.values())
|
|
|
|
return coalesced
|
|
|
|
|
|
class ScriptRequests:
|
|
"""An interface for communicating with a ScriptRunner. Thread-safe.
|
|
|
|
AppSession makes requests of a ScriptRunner through this class, and
|
|
ScriptRunner handles those requests.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self._lock = threading.Lock()
|
|
self._state = ScriptRequestType.CONTINUE
|
|
self._rerun_data = RerunData()
|
|
|
|
def request_stop(self) -> None:
|
|
"""Request that the ScriptRunner stop running. A stopped ScriptRunner
|
|
can't be used anymore. STOP requests succeed unconditionally.
|
|
"""
|
|
with self._lock:
|
|
self._state = ScriptRequestType.STOP
|
|
|
|
def request_rerun(self, new_data: RerunData) -> bool:
|
|
"""Request that the ScriptRunner rerun its script.
|
|
|
|
If the ScriptRunner has been stopped, this request can't be honored:
|
|
return False.
|
|
|
|
Otherwise, record the request and return True. The ScriptRunner will
|
|
handle the rerun request as soon as it reaches an interrupt point.
|
|
"""
|
|
|
|
with self._lock:
|
|
if self._state == ScriptRequestType.STOP:
|
|
# We can't rerun after being stopped.
|
|
return False
|
|
|
|
if self._state == ScriptRequestType.CONTINUE:
|
|
# The script is currently running, and we haven't received a request to
|
|
# rerun it as of yet. We can handle a rerun request unconditionally so
|
|
# just change self._state and set self._rerun_data.
|
|
self._state = ScriptRequestType.RERUN
|
|
|
|
# Convert from a single fragment_id into fragment_id_queue.
|
|
if new_data.fragment_id:
|
|
new_data = replace(
|
|
new_data,
|
|
fragment_id=None,
|
|
fragment_id_queue=[new_data.fragment_id],
|
|
)
|
|
|
|
self._rerun_data = new_data
|
|
return True
|
|
|
|
if self._state == ScriptRequestType.RERUN:
|
|
# We already have an existing Rerun request, so we can coalesce the new
|
|
# rerun request into the existing one.
|
|
|
|
coalesced_states = _coalesce_widget_states(
|
|
self._rerun_data.widget_states, new_data.widget_states
|
|
)
|
|
|
|
if new_data.fragment_id:
|
|
# This RERUN request corresponds to a new fragment run. We append
|
|
# the new fragment ID to the end of the current fragment_id_queue if
|
|
# it isn't already contained in it.
|
|
fragment_id_queue = [*self._rerun_data.fragment_id_queue]
|
|
|
|
if new_data.fragment_id not in fragment_id_queue:
|
|
fragment_id_queue.append(new_data.fragment_id)
|
|
elif new_data.fragment_id_queue:
|
|
# new_data contains a new fragment_id_queue, so we just use it.
|
|
fragment_id_queue = new_data.fragment_id_queue
|
|
else:
|
|
# Otherwise, this is a request to rerun the full script, so we want
|
|
# to clear out any fragments we have queued to run since they'll all
|
|
# be run with the full script anyway.
|
|
fragment_id_queue = []
|
|
|
|
self._rerun_data = RerunData(
|
|
query_string=new_data.query_string,
|
|
widget_states=coalesced_states,
|
|
page_script_hash=new_data.page_script_hash,
|
|
page_name=new_data.page_name,
|
|
fragment_id_queue=fragment_id_queue,
|
|
cached_message_hashes=new_data.cached_message_hashes,
|
|
is_fragment_scoped_rerun=new_data.is_fragment_scoped_rerun,
|
|
is_auto_rerun=new_data.is_auto_rerun,
|
|
context_info=new_data.context_info,
|
|
)
|
|
|
|
return True
|
|
|
|
# We'll never get here
|
|
raise RuntimeError(f"Unrecognized ScriptRunnerState: {self._state}")
|
|
|
|
def on_scriptrunner_yield(self) -> ScriptRequest | None:
|
|
"""Called by the ScriptRunner when it's at a yield point.
|
|
|
|
If we have no request or a RERUN request corresponding to one or more fragments
|
|
(that is not a fragment-scoped rerun), return None.
|
|
|
|
If we have a (full script or fragment-scoped) RERUN request, return the request
|
|
and set our internal state to CONTINUE.
|
|
|
|
If we have a STOP request, return the request and remain stopped.
|
|
"""
|
|
if self._state == ScriptRequestType.CONTINUE or (
|
|
self._state == ScriptRequestType.RERUN
|
|
and _fragment_run_should_not_preempt_script(
|
|
self._rerun_data.fragment_id_queue,
|
|
self._rerun_data.is_fragment_scoped_rerun,
|
|
)
|
|
):
|
|
# We avoid taking the lock in the common cases described above. If a STOP or
|
|
# preempting RERUN request is received after we've taken this code path, it
|
|
# will be handled at the next `on_scriptrunner_yield`, or when
|
|
# `on_scriptrunner_ready` is called.
|
|
return None
|
|
|
|
with self._lock:
|
|
if self._state == ScriptRequestType.RERUN:
|
|
# We already made this check in the fast-path above but need to do so
|
|
# again in case our state changed while we were waiting on the lock.
|
|
if _fragment_run_should_not_preempt_script(
|
|
self._rerun_data.fragment_id_queue,
|
|
self._rerun_data.is_fragment_scoped_rerun,
|
|
):
|
|
return None
|
|
|
|
self._state = ScriptRequestType.CONTINUE
|
|
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
|
|
|
|
if self._state != ScriptRequestType.STOP:
|
|
raise RuntimeError(
|
|
f"Unrecognized ScriptRunnerState: {self._state}. This should never happen."
|
|
)
|
|
return ScriptRequest(ScriptRequestType.STOP)
|
|
|
|
def on_scriptrunner_ready(self) -> ScriptRequest:
|
|
"""Called by the ScriptRunner when it's about to run its script for
|
|
the first time, and also after its script has successfully completed.
|
|
|
|
If we have a RERUN request, return the request and set
|
|
our internal state to CONTINUE.
|
|
|
|
If we have a STOP request or no request, set our internal state
|
|
to STOP.
|
|
"""
|
|
with self._lock:
|
|
if self._state == ScriptRequestType.RERUN:
|
|
self._state = ScriptRequestType.CONTINUE
|
|
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
|
|
|
|
# If we don't have a rerun request, unconditionally change our
|
|
# state to STOP.
|
|
self._state = ScriptRequestType.STOP
|
|
return ScriptRequest(ScriptRequestType.STOP)
|