# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import binascii
import json
from typing import (
    Any,
    Dict,
    Optional,
    Awaitable,
    Union,
    TYPE_CHECKING,
)

import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.websocket import WebSocketHandler
from typing_extensions import Final

from streamlit import config
from streamlit.runtime.app_session import AppSession
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.web.server.server_util import is_url_from_allowed_origins
from streamlit.web.server.server_util import serialize_forward_msg
from .session_client import SessionClient, SessionClientDisconnectedError

if TYPE_CHECKING:
    from .server import Server

LOGGER: Final = get_logger(__name__)


class BrowserWebSocketHandler(WebSocketHandler, SessionClient):
    """Handles a WebSocket connection from the browser"""

    def initialize(self, server: "Server") -> None:
        """Store the owning Server and start with no AppSession.

        The AppSession is created later, in ``open()``, once the WebSocket
        connection is established.
        """
        self._server = server
        self._session: Optional[AppSession] = None
        # The XSRF cookie is normally set when xsrf_form_html is used, but in a
        # pure-Javascript application that does not use any regular forms we just
        # need to read the self.xsrf_token manually to set the cookie as a side
        # effect. See https://www.tornadoweb.org/en/stable/guide/security.html#cross-site-request-forgery-protection
        # for more details.
        if config.get_option("server.enableXsrfProtection"):
            _ = self.xsrf_token

    def check_origin(self, origin: str) -> bool:
        """Set up CORS.

        Accept the connection if Tornado's default origin check passes, or if
        ``is_url_from_allowed_origins`` accepts ``origin``.
        """
        return super().check_origin(origin) or is_url_from_allowed_origins(origin)

    def write_forward_msg(self, msg: ForwardMsg) -> None:
        """Send a ForwardMsg to the browser.

        The message is serialized with ``serialize_forward_msg`` and written
        as a binary WebSocket frame.

        Raises
        ------
        SessionClientDisconnectedError
            If the underlying WebSocket is already closed.
        """
        try:
            self.write_message(serialize_forward_msg(msg), binary=True)
        except tornado.websocket.WebSocketClosedError as e:
            raise SessionClientDisconnectedError from e

    def open(self, *args, **kwargs) -> Optional[Awaitable[None]]:
        """Create an AppSession for the newly opened connection.

        The user's email is read from the ``X-Streamlit-User`` request header,
        which is expected to hold a base64-encoded JSON object with ``email``
        and ``isPublicCloudApp`` keys. If the header is missing or cannot be
        parsed, the placeholder "test@localhost.com" is used instead. For
        public cloud apps the email is withheld and stored as ``None``.
        """
        # Extract user info from the X-Streamlit-User header
        is_public_cloud_app = False

        try:
            header_content = self.request.headers["X-Streamlit-User"]
            payload = base64.b64decode(header_content)
            user_obj = json.loads(payload)
            email = user_obj["email"]
            is_public_cloud_app = user_obj["isPublicCloudApp"]
        except (KeyError, TypeError, ValueError):
            # The header comes from the incoming request, so a malformed value
            # must fall back to the default instead of aborting the connection.
            # ValueError covers binascii.Error (bad base64), JSONDecodeError
            # (bad JSON), UnicodeDecodeError (payload is not valid text) and the
            # ValueError b64decode raises for non-ASCII header text. TypeError
            # covers JSON that decodes to something other than an object.
            email = "test@localhost.com"

        user_info: Dict[str, Optional[str]] = dict()
        if is_public_cloud_app:
            user_info["email"] = None
        else:
            user_info["email"] = email

        self._session = self._server._create_app_session(self, user_info)
        return None

    def on_close(self) -> None:
        """Close this connection's AppSession on the server, if one exists."""
        if not self._session:
            return
        self._server._close_app_session(self._session.id)
        self._session = None

    def get_compression_options(self) -> Optional[Dict[Any, Any]]:
        """Enable WebSocket compression.

        Returning an empty dict enables websocket compression. Returning
        None disables it.

        (See the docstring in the parent class.)
        """
        if config.get_option("server.enableWebsocketCompression"):
            return {}
        return None

    def on_message(self, payload: Union[str, bytes]) -> None:
        """Parse an incoming BackMsg and dispatch it.

        Messages that arrive while no AppSession exists are ignored. A
        ``close_connection`` message stops the server, but only in development
        mode. Every other message type is handed to the AppSession. Any
        exception raised while parsing or handling is logged and then passed
        to ``AppSession.handle_backmsg_exception``.
        """
        if not self._session:
            return

        msg = BackMsg()

        try:
            if isinstance(payload, str):
                # Sanity check. (The frontend should only be sending us bytes;
                # Protobuf.ParseFromString does not accept str input.)
                raise RuntimeError(
                    "WebSocket received an unexpected `str` message. "
                    "(We expect `bytes` only.)"
                )

            msg.ParseFromString(payload)
            LOGGER.debug("Received the following back message:\n%s", msg)

            if msg.WhichOneof("type") == "close_connection":
                # "close_connection" is a special developmentMode-only
                # message used in e2e tests to test disabling widgets.
                if config.get_option("global.developmentMode"):
                    self._server.stop()
                else:
                    LOGGER.warning(
                        "Client tried to close connection when "
                        "not in development mode"
                    )
            else:
                # AppSession handles all other BackMsg types.
                self._session.handle_backmsg(msg)
        except BaseException as e:
            # NOTE(review): BaseException also catches SystemExit and
            # KeyboardInterrupt and forwards them to the session instead of
            # propagating them; confirm this breadth is intended.
            LOGGER.error(e)
            self._session.handle_backmsg_exception(e)