149 lines
5.3 KiB
Python
149 lines
5.3 KiB
Python
# Copyright 2018-2022 Streamlit Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import base64
|
|
import binascii
|
|
import json
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
Optional,
|
|
Awaitable,
|
|
Union,
|
|
TYPE_CHECKING,
|
|
)
|
|
|
|
import tornado.concurrent
|
|
import tornado.locks
|
|
import tornado.netutil
|
|
import tornado.web
|
|
import tornado.websocket
|
|
from tornado.websocket import WebSocketHandler
|
|
from typing_extensions import Final
|
|
|
|
from streamlit import config
|
|
from streamlit.runtime.app_session import AppSession
|
|
from streamlit.logger import get_logger
|
|
from streamlit.proto.BackMsg_pb2 import BackMsg
|
|
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
|
from streamlit.web.server.server_util import is_url_from_allowed_origins
|
|
from streamlit.web.server.server_util import serialize_forward_msg
|
|
from .session_client import SessionClient, SessionClientDisconnectedError
|
|
|
|
if TYPE_CHECKING:
|
|
from .server import Server
|
|
|
|
# Logger for this module, obtained from streamlit.logger.get_logger using the
# module's own name.
LOGGER: Final = get_logger(__name__)
|
|
|
|
|
|
class BrowserWebSocketHandler(WebSocketHandler, SessionClient):
    """Handles a WebSocket connection from the browser.

    An AppSession is created through the server when the socket opens,
    incoming BackMsgs are forwarded to that session, and the session is
    closed through the server when the socket closes.
    """

    def initialize(self, server: "Server") -> None:
        """Store the owning server and set the XSRF cookie if enabled.

        Parameters
        ----------
        server : Server
            The server used to create and close this connection's AppSession.
        """
        self._server = server
        self._session: Optional[AppSession] = None
        # The XSRF cookie is normally set when xsrf_form_html is used, but in a
        # pure-Javascript application that does not use any regular forms we just
        # need to read the self.xsrf_token manually to set the cookie as a side
        # effect. See https://www.tornadoweb.org/en/stable/guide/security.html#cross-site-request-forgery-protection
        # for more details.
        if config.get_option("server.enableXsrfProtection"):
            _ = self.xsrf_token

    def check_origin(self, origin: str) -> bool:
        """Set up CORS.

        The origin is accepted when either the parent class's check passes
        or `is_url_from_allowed_origins` accepts it.
        """
        return super().check_origin(origin) or is_url_from_allowed_origins(origin)

    def write_forward_msg(self, msg: ForwardMsg) -> None:
        """Send a ForwardMsg to the browser.

        Raises
        ------
        SessionClientDisconnectedError
            If the WebSocket is already closed when the message is written.
        """
        try:
            self.write_message(serialize_forward_msg(msg), binary=True)
        except tornado.websocket.WebSocketClosedError as e:
            raise SessionClientDisconnectedError from e

    def open(self, *args, **kwargs) -> Optional[Awaitable[None]]:
        """Create the AppSession for a newly opened connection.

        User info is extracted from the X-Streamlit-User header, which is
        expected to hold a base64-encoded JSON object with "email" and
        "isPublicCloudApp" keys. If the header is missing or cannot be
        interpreted that way, the email falls back to "test@localhost.com".
        When "isPublicCloudApp" is truthy, the email is reported as None.
        """
        # Defaults used when the header is absent or malformed.
        email = "test@localhost.com"
        is_public_cloud_app = False

        try:
            header_content = self.request.headers["X-Streamlit-User"]
            payload = base64.b64decode(header_content)
            user_obj = json.loads(payload)
            header_email = user_obj["email"]
            header_is_public = user_obj["isPublicCloudApp"]
        except (KeyError, TypeError, ValueError):
            # The header is client-controlled, so every malformed shape must
            # end up on the fallback path rather than aborting the handler:
            #  - KeyError: header or one of the expected keys is missing.
            #  - ValueError: covers binascii.Error (bad base64),
            #    json.JSONDecodeError (bad JSON), UnicodeDecodeError (decoded
            #    bytes are not valid text) and the plain ValueError that
            #    b64decode raises for non-ASCII input.
            #  - TypeError: the JSON payload is not an object (e.g. a list,
            #    number or null), so it cannot be indexed by key.
            pass
        else:
            # Only trust the header values once both keys were read.
            email = header_email
            is_public_cloud_app = header_is_public

        user_info: Dict[str, Optional[str]] = {
            "email": None if is_public_cloud_app else email,
        }

        self._session = self._server._create_app_session(self, user_info)
        return None

    def on_close(self) -> None:
        """Close this connection's AppSession, if one was created."""
        if not self._session:
            return
        self._server._close_app_session(self._session.id)
        self._session = None

    def get_compression_options(self) -> Optional[Dict[Any, Any]]:
        """Enable WebSocket compression.

        Returning an empty dict enables websocket compression. Returning
        None disables it.

        (See the docstring in the parent class.)
        """
        if config.get_option("server.enableWebsocketCompression"):
            return {}
        return None

    def on_message(self, payload: Union[str, bytes]) -> None:
        """Parse an incoming BackMsg and hand it to the AppSession.

        Messages are ignored when no session exists. Any error raised while
        parsing or handling the message is logged and passed to the
        session's `handle_backmsg_exception`.

        Parameters
        ----------
        payload : str or bytes
            The raw WebSocket message. Only `bytes` is expected; a `str`
            payload is treated as an error.
        """
        if not self._session:
            return

        msg = BackMsg()

        try:
            if isinstance(payload, str):
                # Sanity check. (The frontend should only be sending us bytes;
                # Protobuf.ParseFromString does not accept str input.)
                raise RuntimeError(
                    "WebSocket received an unexpected `str` message. "
                    "(We expect `bytes` only.)"
                )

            msg.ParseFromString(payload)
            LOGGER.debug("Received the following back message:\n%s", msg)

            if msg.WhichOneof("type") == "close_connection":
                # "close_connection" is a special developmentMode-only
                # message used in e2e tests to test disabling widgets.
                if config.get_option("global.developmentMode"):
                    self._server.stop()
                else:
                    LOGGER.warning(
                        "Client tried to close connection when "
                        "not in development mode"
                    )
            else:
                # AppSession handles all other BackMsg types.
                self._session.handle_backmsg(msg)

        # NOTE(review): BaseException also intercepts KeyboardInterrupt and
        # SystemExit; kept as-is to preserve existing behavior -- confirm
        # whether `Exception` would be sufficient here.
        except BaseException as e:
            LOGGER.error(e)
            self._session.handle_backmsg_exception(e)