# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Iterator, MutableMapping
from typing import Any, Final

from streamlit import logger as _logger
from streamlit import runtime
from streamlit.elements.lib.utils import Key
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.state.common import require_valid_user_key
from streamlit.runtime.state.safe_session_state import SafeSessionState
from streamlit.runtime.state.session_state import SessionState

_LOGGER: Final = _logger.get_logger(__name__)


_state_use_warning_already_displayed: bool = False
# The mock session state is used as a fallback if the script is run without `streamlit run`
_mock_session_state: SafeSessionState | None = None


def get_session_state() -> SafeSessionState:
    """Get the SessionState object for the current session.

    Note that in streamlit scripts, this function should not be called
    directly. Instead, SessionState objects should be accessed via
    st.session_state.

    Returns
    -------
    SafeSessionState
        The session state attached to the current script run context. If
        there is no script run context, a lazily created module-level mock
        session state is returned instead.
    """
    global _state_use_warning_already_displayed  # noqa: PLW0603
    from streamlit.runtime.scriptrunner_utils.script_run_context import (
        get_script_run_ctx,
    )

    ctx = get_script_run_ctx()

    # If there is no script run context because the script is run bare, we
    # use a global mock session state version to allow bare script execution (via python script.py)
    if ctx is None:
        if not _state_use_warning_already_displayed:
            # The flag is flipped before the runtime check, so the warning
            # decision is made only on the first context-less call.
            _state_use_warning_already_displayed = True
            if not runtime.exists():
                _LOGGER.warning(
                    "Session state does not function when running a script without `streamlit run`"
                )

        global _mock_session_state  # noqa: PLW0603

        if _mock_session_state is None:
            # Lazy initialize the mock session state
            _mock_session_state = SafeSessionState(SessionState(), lambda: None)
        return _mock_session_state
    return ctx.session_state


class SessionStateProxy(MutableMapping[Key, Any]):
    """A stateless singleton that proxies `st.session_state` interactions
    to the current script thread's SessionState instance.

    The proxy API differs slightly from SessionState: it does not allow
    callers to get, set, or iterate over "keyless" widgets (that is, widgets
    that were created without a user_key, and have autogenerated keys).
    """

    def __iter__(self) -> Iterator[Any]:
        """Iterator over user state and keyed widget values."""
        # TODO: this is unsafe if fastReruns is true! Let's deprecate/remove.
        return iter(get_session_state().filtered_state)

    def __len__(self) -> int:
        """Number of user state and keyed widget values in session_state."""
        return len(get_session_state().filtered_state)

    def __str__(self) -> str:
        """String representation of user state and keyed widget values."""
        return str(get_session_state().filtered_state)

    def __getitem__(self, key: Key) -> Any:
        """Return the state or widget value with the given key.

        Raises
        ------
        StreamlitAPIException
            If the key is not a valid SessionState user key.
        """
        # Keys are normalized to strings before validation and lookup.
        key = str(key)
        require_valid_user_key(key)
        return get_session_state()[key]

    @gather_metrics("session_state.set_item")
    def __setitem__(self, key: Key, value: Any) -> None:
        """Set the value of the given key.

        Raises
        ------
        StreamlitAPIException
            If the key is not a valid SessionState user key.
        """
        key = str(key)
        require_valid_user_key(key)
        get_session_state()[key] = value

    def __delitem__(self, key: Key) -> None:
        """Delete the value with the given key.

        Raises
        ------
        StreamlitAPIException
            If the key is not a valid SessionState user key.
        """
        key = str(key)
        require_valid_user_key(key)
        del get_session_state()[key]

    def __getattr__(self, key: str) -> Any:
        """Return the value stored under ``key`` using attribute syntax.

        Delegates to item access, so ``proxy.foo`` reads ``proxy["foo"]``.

        Raises
        ------
        AttributeError
            If item access for ``key`` raises a KeyError.
        """
        try:
            return self[key]
        except KeyError as err:
            # Chain explicitly so the traceback shows the KeyError as the
            # cause of the AttributeError rather than an unrelated failure.
            raise AttributeError(_missing_attr_error_message(key)) from err

    @gather_metrics("session_state.set_attr")
    def __setattr__(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` using attribute syntax.

        Delegates to item assignment, so ``proxy.foo = 1`` performs
        ``proxy["foo"] = 1``.
        """
        self[key] = value

    def __delattr__(self, key: str) -> None:
        """Delete the value stored under ``key`` using attribute syntax.

        Delegates to item deletion, so ``del proxy.foo`` performs
        ``del proxy["foo"]``.

        Raises
        ------
        AttributeError
            If item deletion for ``key`` raises a KeyError.
        """
        try:
            del self[key]
        except KeyError as err:
            # Chain explicitly so the traceback shows the KeyError as the
            # cause of the AttributeError rather than an unrelated failure.
            raise AttributeError(_missing_attr_error_message(key)) from err

    def to_dict(self) -> dict[str, Any]:
        """Return a dict containing all session_state and keyed widget values."""
        return get_session_state().filtered_state


def _missing_attr_error_message(attr_name: str) -> str:
    """Build the AttributeError message for a missing session_state attribute."""
    return (
        f'st.session_state has no attribute "{attr_name}". Did you forget to initialize it? '
        f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
    )