# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Final, cast

from streamlit.logger import get_logger
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.session_manager import (
    ActiveSessionInfo,
    SessionClient,
    SessionInfo,
    SessionManager,
    SessionStorage,
)

if TYPE_CHECKING:
    from streamlit.runtime.script_data import ScriptData
    from streamlit.runtime.scriptrunner.script_cache import ScriptCache
    from streamlit.runtime.uploaded_file_manager import UploadedFileManager

_LOGGER: Final = get_logger(__name__)


class WebsocketSessionManager(SessionManager):
    """A SessionManager used to manage sessions with lifecycles tied to those of a
    browser tab's websocket connection.

    WebsocketSessionManagers differentiate between "active" and "inactive"
    sessions. Active sessions are those with a currently active websocket
    connection. Inactive sessions are sessions without one.

    Eventual cleanup of inactive sessions is a detail left to the specific
    SessionStorage that a WebsocketSessionManager is instantiated with.
    """

    def __init__(
        self,
        session_storage: SessionStorage,
        uploaded_file_manager: UploadedFileManager,
        script_cache: ScriptCache,
        message_enqueued_callback: Callable[[], None] | None,
    ) -> None:
        """Create the manager.

        Parameters
        ----------
        session_storage
            Holds inactive (disconnected) sessions so that they can later be
            reconnected to or closed.
        uploaded_file_manager
            Passed to every AppSession this manager creates.
        script_cache
            Passed to every AppSession this manager creates; cleared whenever
            the last active session goes away.
        message_enqueued_callback
            Passed to every AppSession this manager creates.
        """
        self._session_storage = session_storage
        self._uploaded_file_mgr = uploaded_file_manager
        self._script_cache = script_cache
        self._message_enqueued_callback = message_enqueued_callback

        # Mapping of AppSession.id -> ActiveSessionInfo.
        self._active_session_info_by_id: dict[str, ActiveSessionInfo] = {}

    def connect_session(
        self,
        client: SessionClient,
        script_data: ScriptData,
        user_info: dict[str, str | bool | None],
        existing_session_id: str | None = None,
        session_id_override: str | None = None,
    ) -> str:
        """Attach ``client`` to a session and return that session's ID.

        If ``existing_session_id`` names a session held in session storage,
        that session is reactivated for ``client`` (keeping its script run
        count). Otherwise a new AppSession is created and registered as
        active.

        Parameters
        ----------
        client
            The client to associate with the session.
        script_data
            Used only when a new AppSession is created.
        user_info
            Used only when a new AppSession is created.
        existing_session_id
            ID of a previously disconnected session to reconnect to. If that
            ID belongs to a session that is still active, a warning is logged
            and a new session is created instead.
        session_id_override
            Passed to a newly created AppSession. Must not be truthy at the
            same time as ``existing_session_id``.

        Returns
        -------
        str
            The ID of the now-active session.

        Raises
        ------
        RuntimeError
            If both ``existing_session_id`` and ``session_id_override`` are
            truthy, or if a newly created session's ID is already registered
            as active.
        """
        if existing_session_id and session_id_override:
            raise RuntimeError(
                "Only one of existing_session_id and session_id_override should be truthy. "
                "This should never happen."
            )

        already_active = existing_session_id in self._active_session_info_by_id
        if already_active:
            _LOGGER.warning(
                "Session with id %s is already connected! Connecting to a new session.",
                existing_session_id,
            )

        # A session that is already active is not reused; only sessions held
        # in session storage are eligible for reconnection.
        stored_info = (
            self._session_storage.get(existing_session_id)
            if existing_session_id and not already_active
            else None
        )
        if stored_info:
            return self._reactivate_session(client, stored_info)

        session = AppSession(
            script_data=script_data,
            uploaded_file_manager=self._uploaded_file_mgr,
            script_cache=self._script_cache,
            message_enqueued_callback=self._message_enqueued_callback,
            user_info=user_info,
            session_id_override=session_id_override,
        )

        _LOGGER.debug(
            "Created new session for client %s. Session ID: %s", id(client), session.id
        )

        if session.id in self._active_session_info_by_id:
            raise RuntimeError(
                f"session.id '{session.id}' registered multiple times. "
                "This should never happen."
            )

        self._active_session_info_by_id[session.id] = ActiveSessionInfo(client, session)
        return session.id

    def _reactivate_session(
        self, client: SessionClient, session_info: SessionInfo
    ) -> str:
        """Move a session from session storage back into the active set.

        Re-registers the session's file watchers, records it as active for
        ``client`` with its stored script run count, removes it from session
        storage, and returns its ID.
        """
        session = session_info.session
        session.register_file_watchers()

        self._active_session_info_by_id[session.id] = ActiveSessionInfo(
            client,
            session,
            session_info.script_run_count,
        )
        self._session_storage.delete(session.id)

        return session.id

    def disconnect_session(self, session_id: str) -> None:
        """Deactivate an active session, keeping it in session storage.

        The session is asked to stop its running script and its file watchers
        are disconnected; it is then saved to session storage with no client,
        from where a later ``connect_session`` call may reconnect to it.
        Does nothing if ``session_id`` is not an active session.
        """
        active_session_info = self._active_session_info_by_id.get(session_id)
        if active_session_info is None:
            return

        session = active_session_info.session
        session.request_script_stop()
        session.disconnect_file_watchers()

        self._session_storage.save(
            SessionInfo(
                client=None,
                session=session,
                script_run_count=active_session_info.script_run_count,
            )
        )
        del self._active_session_info_by_id[session_id]
        self._clear_script_cache_if_no_active_sessions()

    def get_active_session_info(self, session_id: str) -> ActiveSessionInfo | None:
        """Return the ActiveSessionInfo for ``session_id``, or None if not active."""
        return self._active_session_info_by_id.get(session_id)

    def is_active_session(self, session_id: str) -> bool:
        """Return True if ``session_id`` is currently an active session."""
        return session_id in self._active_session_info_by_id

    def list_active_sessions(self) -> list[ActiveSessionInfo]:
        """Return a new list containing every active session's info."""
        return list(self._active_session_info_by_id.values())

    def close_session(self, session_id: str) -> None:
        """Shut down a session, whether it is active or held in session storage.

        An active session is removed from the active set and shut down. An
        inactive session is deleted from session storage and shut down.
        Nothing is done if session storage has no entry for the ID either.
        """
        active_session_info = self._active_session_info_by_id.pop(session_id, None)
        if active_session_info is not None:
            active_session_info.session.shutdown()
            self._clear_script_cache_if_no_active_sessions()
            return

        session_info = self._session_storage.get(session_id)
        if session_info:
            self._session_storage.delete(session_id)
            session_info.session.shutdown()

    def _clear_script_cache_if_no_active_sessions(self) -> None:
        """Clear the script cache when no active sessions remain."""
        if not self._active_session_info_by_id:
            # Avoid stale cached scripts when all file watchers and sessions are disconnected
            self._script_cache.clear()

    def get_session_info(self, session_id: str) -> SessionInfo | None:
        """Return info for ``session_id``, preferring the active set over storage.

        If the session is active, its ActiveSessionInfo is returned (``cast``
        only informs the type checker). Otherwise the result of looking the ID
        up in session storage is returned.
        """
        session_info = self.get_active_session_info(session_id)
        if session_info:
            return cast("SessionInfo", session_info)
        return self._session_storage.get(session_id)

    def list_sessions(self) -> list[SessionInfo]:
        """Return all sessions: active sessions first, then those in storage."""
        return (
            cast("list[SessionInfo]", self.list_active_sessions())
            + self._session_storage.list()
        )