# team-10/venv/Lib/site-packages/streamlit/web/server/browser_websocket_handler.py
# 2025-08-02 02:00:33 +02:00
#
# 149 lines
# 5.3 KiB
# Python

# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import binascii
import json
from typing import (
Any,
Dict,
Optional,
Awaitable,
Union,
TYPE_CHECKING,
)
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.websocket import WebSocketHandler
from typing_extensions import Final
from streamlit import config
from streamlit.runtime.app_session import AppSession
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.web.server.server_util import is_url_from_allowed_origins
from streamlit.web.server.server_util import serialize_forward_msg
from .session_client import SessionClient, SessionClientDisconnectedError
if TYPE_CHECKING:
from .server import Server
# Module-level logger used by BrowserWebSocketHandler to report incoming
# BackMsgs (debug level), rejected close_connection requests (warning level)
# and message-handling errors.
LOGGER: Final = get_logger(__name__)
class BrowserWebSocketHandler(WebSocketHandler, SessionClient):
    """Handles a WebSocket connection from the browser.

    Each connection owns at most one AppSession: it is created in ``open``,
    receives incoming BackMsgs via ``on_message`` and is closed in
    ``on_close``. Outgoing ForwardMsgs are sent through ``write_forward_msg``.
    """

    def initialize(self, server: "Server") -> None:
        """Store the owning server and reset per-connection state.

        Parameters
        ----------
        server : Server
            The Streamlit server used to create and close this connection's
            AppSession.
        """
        self._server = server
        self._session: Optional[AppSession] = None

        # The XSRF cookie is normally set when xsrf_form_html is used, but in a
        # pure-Javascript application that does not use any regular forms we just
        # need to read the self.xsrf_token manually to set the cookie as a side
        # effect. See https://www.tornadoweb.org/en/stable/guide/security.html#cross-site-request-forgery-protection
        # for more details.
        if config.get_option("server.enableXsrfProtection"):
            _ = self.xsrf_token

    def check_origin(self, origin: str) -> bool:
        """Set up CORS: decide whether to accept a WebSocket from ``origin``.

        The connection is accepted when Tornado's default origin check
        passes, or when ``origin`` is in Streamlit's allowed origins.
        """
        return super().check_origin(origin) or is_url_from_allowed_origins(origin)

    def write_forward_msg(self, msg: ForwardMsg) -> None:
        """Send a ForwardMsg to the browser as a binary WebSocket message.

        Raises
        ------
        SessionClientDisconnectedError
            If the underlying WebSocket is already closed.
        """
        try:
            self.write_message(serialize_forward_msg(msg), binary=True)
        except tornado.websocket.WebSocketClosedError as e:
            raise SessionClientDisconnectedError from e

    def open(self, *args, **kwargs) -> Optional[Awaitable[None]]:
        """Create this connection's AppSession.

        User info is read from the optional ``X-Streamlit-User`` request
        header, which is decoded as base64 and then parsed as JSON with
        ``email`` and ``isPublicCloudApp`` fields. If the header is absent,
        or is malformed in any way, the placeholder email
        ``"test@localhost.com"`` is used instead.
        """
        # Extract user info from the X-Streamlit-User header
        is_public_cloud_app = False
        try:
            header_content = self.request.headers["X-Streamlit-User"]
            payload = base64.b64decode(header_content)
            user_obj = json.loads(payload)
            email = user_obj["email"]
            is_public_cloud_app = user_obj["isPublicCloudApp"]
        except (
            KeyError,  # header missing, or a required JSON field missing
            TypeError,  # decoded JSON is not an object (e.g. a list or null)
            binascii.Error,  # malformed base64
            ValueError,  # non-ASCII header, invalid UTF-8 or invalid JSON
        ):
            # is_public_cloud_app is still False here: its assignment is the
            # last statement of the try body. So a header missing either
            # field falls back to the placeholder email.
            email = "test@localhost.com"

        user_info: Dict[str, Optional[str]] = dict()
        # For public Cloud apps the email is withheld (set to None).
        user_info["email"] = None if is_public_cloud_app else email

        self._session = self._server._create_app_session(self, user_info)
        return None

    def on_close(self) -> None:
        """Close this connection's AppSession, if one was created."""
        if not self._session:
            return
        self._server._close_app_session(self._session.id)
        self._session = None

    def get_compression_options(self) -> Optional[Dict[Any, Any]]:
        """Enable WebSocket compression.

        Returning an empty dict enables websocket compression. Returning
        None disables it.

        (See the docstring in the parent class.)
        """
        if config.get_option("server.enableWebsocketCompression"):
            return {}
        return None

    def on_message(self, payload: Union[str, bytes]) -> None:
        """Parse an incoming BackMsg and dispatch it.

        Messages that arrive while no session exists are ignored. Any
        ``Exception`` raised while parsing or handling the message is logged
        (with traceback) and forwarded to the session's
        ``handle_backmsg_exception``. Only ``Exception`` is caught, so
        ``KeyboardInterrupt`` and ``SystemExit`` still propagate.
        """
        # Bind the session locally so the except path below uses the same
        # session that the try body did.
        session = self._session
        if not session:
            return

        msg = BackMsg()

        try:
            if isinstance(payload, str):
                # Sanity check. (The frontend should only be sending us bytes;
                # Protobuf.ParseFromString does not accept str input.)
                raise RuntimeError(
                    "WebSocket received an unexpected `str` message. "
                    "(We expect `bytes` only.)"
                )

            msg.ParseFromString(payload)
            LOGGER.debug("Received the following back message:\n%s", msg)

            if msg.WhichOneof("type") == "close_connection":
                # "close_connection" is a special developmentMode-only
                # message used in e2e tests to test disabling widgets.
                if config.get_option("global.developmentMode"):
                    self._server.stop()
                else:
                    LOGGER.warning(
                        "Client tried to close connection when "
                        "not in development mode"
                    )
            else:
                # AppSession handles all other BackMsg types.
                session.handle_backmsg(msg)
        except Exception as e:
            # logger.exception logs the same message as before, plus the
            # traceback.
            LOGGER.exception(e)
            session.handle_backmsg_exception(e)