# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import errno
import logging
import os
import socket
import sys
import time
import traceback
from enum import Enum
from typing import (
    Any,
    Dict,
    Optional,
    Tuple,
    Callable,
    List,
)

import click
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.httpserver import HTTPServer

from streamlit import config
from streamlit import file_util
from streamlit import source_util
from streamlit import util
from streamlit.runtime.caching import (
    get_memo_stats_provider,
    get_singleton_stats_provider,
)
from streamlit.components.v1.components import ComponentRegistry
from streamlit.config_option import ConfigOption
from streamlit.runtime.legacy_caching.caching import _mem_caches
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.forward_msg_cache import (
    ForwardMsgCache,
    create_reference_msg,
    populate_hash_if_needed,
)
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
from streamlit.runtime.session_data import SessionData
from streamlit.runtime.state import (
    SCRIPT_RUN_WITHOUT_ERRORS_KEY,
    SessionStateStatProvider,
)
from streamlit.runtime.stats import StatsManager
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.watcher import LocalSourcesWatcher
from streamlit.web.server.routes import (
    AddSlashHandler,
    AssetsFileHandler,
    HealthHandler,
    MediaFileHandler,
    MessageCacheHandler,
    StaticFileHandler,
)
from streamlit.web.server.server_util import (
    get_max_message_size_bytes,
    is_cacheable_msg,
    make_url_path_regex,
)
from streamlit.web.server.upload_file_request_handler import (
    UploadFileRequestHandler,
    UPLOAD_FILE_ROUTE,
)

from .browser_websocket_handler import BrowserWebSocketHandler
from .component_request_handler import ComponentRequestHandler
from .session_client import SessionClient, SessionClientDisconnectedError
from .stats_request_handler import StatsRequestHandler

LOGGER = get_logger(__name__)

TORNADO_SETTINGS = {
    # Gzip HTTP responses.
    "compress_response": True,
    # Ping every 1s to keep WS alive.
    # 2021.06.22: this value was previously 20s, and was causing
    # connection instability for a small number of users. This smaller
    # ping_interval fixes that instability.
    # https://github.com/streamlit/streamlit/issues/3196
    "websocket_ping_interval": 1,
    # If we don't get a ping response within 30s, the connection
    # is timed out.
    "websocket_ping_timeout": 30,
}

# When server.port is not available, the server will look for the next
# available port, retrying up to MAX_PORT_SEARCH_RETRIES times.
MAX_PORT_SEARCH_RETRIES = 100

# When server.address starts with this prefix, the server will bind
# to a Unix socket.
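# (For example, setting server.address to "unix:///tmp/streamlit.sock" makes
# the server listen on /tmp/streamlit.sock instead of a TCP port.)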
UNIX_SOCKET_PREFIX = "unix://" # Wait for the script run result for 60s and if no result is available give up SCRIPT_RUN_CHECK_TIMEOUT = 60 class SessionInfo: """Type stored in our _session_info_by_id dict. For each AppSession, the server tracks that session's script_run_count. This is used to track the age of messages in the ForwardMsgCache. """ def __init__(self, client: SessionClient, session: AppSession): """Initialize a SessionInfo instance. Parameters ---------- session : AppSession The AppSession object. client : SessionClient The concrete SessionClient for this session. """ self.session = session self.client = client self.script_run_count = 0 def __repr__(self) -> str: return util.repr_(self) class State(Enum): INITIAL = "INITIAL" WAITING_FOR_FIRST_SESSION = "WAITING_FOR_FIRST_SESSION" ONE_OR_MORE_SESSIONS_CONNECTED = "ONE_OR_MORE_SESSIONS_CONNECTED" NO_SESSIONS_CONNECTED = "NO_SESSIONS_CONNECTED" STOPPING = "STOPPING" STOPPED = "STOPPED" class RetriesExceeded(Exception): pass def server_port_is_manually_set() -> bool: return config.is_manually_set("server.port") def server_address_is_unix_socket() -> bool: address = config.get_option("server.address") return address is not None and address.startswith(UNIX_SOCKET_PREFIX) def start_listening(app: tornado.web.Application) -> None: """Makes the server start listening at the configured port. In case the port is already taken it tries listening to the next available port. It will error after MAX_PORT_SEARCH_RETRIES attempts. """ http_server = HTTPServer( app, max_buffer_size=config.get_option("server.maxUploadSize") * 1024 * 1024 ) if server_address_is_unix_socket(): start_listening_unix_socket(http_server) else: start_listening_tcp_socket(http_server) def start_listening_unix_socket(http_server: HTTPServer) -> None: address = config.get_option("server.address") file_name = os.path.expanduser(address[len(UNIX_SOCKET_PREFIX) :]) unix_socket = tornado.netutil.bind_unix_socket(file_name) http_server.add_socket(unix_socket) def start_listening_tcp_socket(http_server: HTTPServer) -> None: call_count = 0 port = None while call_count < MAX_PORT_SEARCH_RETRIES: address = config.get_option("server.address") port = config.get_option("server.port") try: http_server.listen(port, address) break # It worked! So let's break out of the loop. except (OSError, socket.error) as e: if e.errno == errno.EADDRINUSE: if server_port_is_manually_set(): LOGGER.error("Port %s is already in use", port) sys.exit(1) else: LOGGER.debug( "Port %s already in use, trying to use the next one.", port ) port += 1 # Save port 3000 because it is used for the development # server in the front end. if port == 3000: port += 1 config.set_option( "server.port", port, ConfigOption.STREAMLIT_DEFINITION ) call_count += 1 else: raise if call_count >= MAX_PORT_SEARCH_RETRIES: raise RetriesExceeded( f"Cannot start Streamlit server. Port {port} is already in use, and " f"Streamlit was unable to find a free port after {MAX_PORT_SEARCH_RETRIES} attempts.", ) class Server: def __init__(self, main_script_path: str, command_line: Optional[str]): """Create the server. It won't be started yet.""" _set_tornado_log_levels() self._main_script_path = main_script_path self._command_line = command_line if command_line is not None else "" # Will be set when we start. self._eventloop: Optional[asyncio.AbstractEventLoop] = None # Mapping of AppSession.id -> SessionInfo. 
        self._session_info_by_id: Dict[str, SessionInfo] = {}

        self._must_stop = tornado.locks.Event()
        self._state = State.INITIAL
        self._message_cache = ForwardMsgCache()
        self._uploaded_file_mgr = UploadedFileManager()
        self._uploaded_file_mgr.on_files_updated.connect(self.on_files_updated)
        self._session_data: Optional[SessionData] = None
        self._has_connection = tornado.locks.Condition()
        self._need_send_data = tornado.locks.Event()

        # StatsManager
        self._stats_mgr = StatsManager()
        self._stats_mgr.register_provider(get_memo_stats_provider())
        self._stats_mgr.register_provider(get_singleton_stats_provider())
        self._stats_mgr.register_provider(_mem_caches)
        self._stats_mgr.register_provider(self._message_cache)
        self._stats_mgr.register_provider(in_memory_file_manager)
        self._stats_mgr.register_provider(self._uploaded_file_mgr)
        self._stats_mgr.register_provider(
            SessionStateStatProvider(self._session_info_by_id)
        )

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def main_script_path(self) -> str:
        return self._main_script_path

    def on_files_updated(self, session_id: str) -> None:
        """Event handler for UploadedFileManager.on_files_updated.

        Ensures that uploaded files from stale sessions get deleted.
        """
        session_info = self._get_session_info(session_id)
        if session_info is None:
            # If an uploaded file doesn't belong to an existing session,
            # remove it so it doesn't stick around forever.
            self._uploaded_file_mgr.remove_session_files(session_id)

    def _get_session_info(self, session_id: str) -> Optional[SessionInfo]:
        """Return the SessionInfo with the given id, or None if no such
        session exists.
        """
        return self._session_info_by_id.get(session_id, None)

    def is_active_session(self, session_id: str) -> bool:
        """True if the session_id belongs to an active session."""
        return session_id in self._session_info_by_id

    async def start(self, on_started: Callable[["Server"], Any]) -> None:
        """Start the server.

        Parameters
        ----------
        on_started : callable
            A callback that will be called when the server's run-loop has
            started, and the server is ready to begin receiving clients.
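
        Example
        -------
        A minimal sketch of driving the server directly (Streamlit's own CLI
        startup code normally does this for you; the script path below is only
        a placeholder)::

            server = Server("path/to/app.py", command_line=None)
            asyncio.run(server.start(lambda s: print("server is ready")))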
""" if self._state != State.INITIAL: raise RuntimeError("Server has already been started") LOGGER.debug("Starting server...") app = self._create_app() start_listening(app) port = config.get_option("server.port") LOGGER.debug("Server started on port %s", port) await self._loop_coroutine(on_started) def _create_app(self) -> tornado.web.Application: """Create our tornado web app.""" base = config.get_option("server.baseUrlPath") routes: List[Any] = [ ( make_url_path_regex(base, "stream"), BrowserWebSocketHandler, dict(server=self), ), ( make_url_path_regex(base, "healthz"), HealthHandler, dict(callback=lambda: self.is_ready_for_browser_connection), ), ( make_url_path_regex(base, "message"), MessageCacheHandler, dict(cache=self._message_cache), ), ( make_url_path_regex(base, "st-metrics"), StatsRequestHandler, dict(stats_manager=self._stats_mgr), ), ( make_url_path_regex( base, UPLOAD_FILE_ROUTE, ), UploadFileRequestHandler, dict( file_mgr=self._uploaded_file_mgr, is_active_session=self.is_active_session, ), ), ( make_url_path_regex(base, "assets/(.*)"), AssetsFileHandler, {"path": "%s/" % file_util.get_assets_dir()}, ), (make_url_path_regex(base, "media/(.*)"), MediaFileHandler, {"path": ""}), ( make_url_path_regex(base, "component/(.*)"), ComponentRequestHandler, dict(registry=ComponentRegistry.instance()), ), ] if config.get_option("server.scriptHealthCheckEnabled"): routes.extend( [ ( make_url_path_regex(base, "script-health-check"), HealthHandler, dict(callback=lambda: self.does_script_run_without_error()), ) ] ) if config.get_option("global.developmentMode"): LOGGER.debug("Serving static content from the Node dev server") else: static_path = file_util.get_static_dir() LOGGER.debug("Serving static content from %s", static_path) routes.extend( [ ( make_url_path_regex(base, "(.*)"), StaticFileHandler, { "path": "%s/" % static_path, "default_filename": "index.html", "get_pages": lambda: set( [ page_info["page_name"] for page_info in source_util.get_pages( self.main_script_path ).values() ] ), }, ), (make_url_path_regex(base, trailing_slash=False), AddSlashHandler), ] ) return tornado.web.Application( routes, cookie_secret=config.get_option("server.cookieSecret"), xsrf_cookies=config.get_option("server.enableXsrfProtection"), # Set the websocket message size. The default value is too low. websocket_max_message_size=get_max_message_size_bytes(), **TORNADO_SETTINGS, # type: ignore[arg-type] ) def _set_state(self, new_state: State) -> None: LOGGER.debug("Server state: %s -> %s" % (self._state, new_state)) self._state = new_state @property async def is_ready_for_browser_connection(self) -> Tuple[bool, str]: if self._state not in (State.INITIAL, State.STOPPING, State.STOPPED): return True, "ok" return False, "unavailable" async def does_script_run_without_error(self) -> Tuple[bool, str]: """Load and execute the app's script to verify it runs without an error. Returns ------- (True, "ok") if the script completes without error, or (False, err_msg) if the script raises an exception. 
""" session = AppSession( event_loop=self._get_eventloop(), session_data=SessionData(self._main_script_path, self._command_line), uploaded_file_manager=self._uploaded_file_mgr, message_enqueued_callback=self._enqueued_some_message, local_sources_watcher=LocalSourcesWatcher(self._main_script_path), user_info={"email": "test@test.com"}, ) try: session.request_rerun(None) now = time.perf_counter() while ( SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state and (time.perf_counter() - now) < SCRIPT_RUN_CHECK_TIMEOUT ): await asyncio.sleep(0.1) if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state: return False, "timeout" ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] msg = "ok" if ok else "error" return ok, msg finally: session.shutdown() @property def browser_is_connected(self) -> bool: return self._state == State.ONE_OR_MORE_SESSIONS_CONNECTED @property def is_running_hello(self) -> bool: from streamlit.hello import Hello return self._main_script_path == Hello.__file__ async def _loop_coroutine( self, on_started: Optional[Callable[["Server"], Any]] = None ) -> None: try: if self._state == State.INITIAL: self._set_state(State.WAITING_FOR_FIRST_SESSION) elif self._state == State.ONE_OR_MORE_SESSIONS_CONNECTED: pass else: raise RuntimeError(f"Bad server state at start: {self._state}") # Store the eventloop we're running on so that we can schedule # callbacks on it when necessary. (We can't just call # `asyncio.get_running_loop()` whenever we like, because we have # some functions, e.g. `stop`, that can be called from other # threads, and `asyncio.get_running_loop()` is thread-specific.) self._eventloop = asyncio.get_running_loop() if on_started is not None: on_started(self) while not self._must_stop.is_set(): if self._state == State.WAITING_FOR_FIRST_SESSION: await asyncio.wait( [self._must_stop.wait(), self._has_connection.wait()], return_when=asyncio.FIRST_COMPLETED, ) elif self._state == State.ONE_OR_MORE_SESSIONS_CONNECTED: self._need_send_data.clear() # Shallow-clone our sessions into a list, so we can iterate # over it and not worry about whether it's being changed # outside this coroutine. session_infos = list(self._session_info_by_id.values()) for session_info in session_infos: msg_list = session_info.session.flush_browser_queue() for msg in msg_list: try: self._send_message(session_info, msg) except SessionClientDisconnectedError: self._close_app_session(session_info.session.id) await asyncio.sleep(0) await asyncio.sleep(0) await asyncio.sleep(0.01) elif self._state == State.NO_SESSIONS_CONNECTED: await asyncio.wait( [self._must_stop.wait(), self._has_connection.wait()], return_when=asyncio.FIRST_COMPLETED, ) else: # Break out of the thread loop if we encounter any other state. break await asyncio.wait( [self._must_stop.wait(), self._need_send_data.wait()], return_when=asyncio.FIRST_COMPLETED, ) # Shut down all AppSessions for session_info in list(self._session_info_by_id.values()): session_info.session.shutdown() self._set_state(State.STOPPED) except Exception: # Can't just re-raise here because co-routines use Tornado # exceptions for control flow, which appears to swallow the reraised # exception. traceback.print_exc() LOGGER.info( """ Please report this bug at https://github.com/streamlit/streamlit/issues. """ ) def _send_message(self, session_info: SessionInfo, msg: ForwardMsg) -> None: """Send a message to a client. 

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with the websocket
        msg : ForwardMsg
            The message to send to the client

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                msg, session_info.session, session_info.script_run_count
            ):
                # This session has probably cached this message. Send
                # a reference instead.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(
                msg, session_info.session, session_info.script_run_count
            )

        # If this was a `script_finished` message, we increment the
        # script_run_count for this session, and update the cache
        if (
            msg.WhichOneof("type") == "script_finished"
            and msg.script_finished == ForwardMsg.FINISHED_SUCCESSFULLY
        ):
            LOGGER.debug(
                "Script run finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.script_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.script_run_count
            )

        # Ship it off!
        session_info.client.write_forward_msg(msg_to_send)

    def _enqueued_some_message(self) -> None:
        self._get_eventloop().call_soon_threadsafe(self._need_send_data.set)

    def stop(self) -> None:
        click.secho("  Stopping...", fg="blue")
        self._set_state(State.STOPPING)
        self._get_eventloop().call_soon_threadsafe(self._must_stop.set)

    def _create_app_session(
        self, client: SessionClient, user_info: Dict[str, Optional[str]]
    ) -> AppSession:
        """Register a connected browser with the server.

        Parameters
        ----------
        client : SessionClient
            The SessionClient for sending data to the session's client.
        user_info : Dict
            A dict that contains information about the current user. For now,
            it only contains the user's email address.

            {
                "email": "example@example.com"
            }

        Returns
        -------
        AppSession
            The newly-created AppSession for this browser connection.

        """
        session = AppSession(
            event_loop=self._get_eventloop(),
            session_data=SessionData(self._main_script_path, self._command_line),
            uploaded_file_manager=self._uploaded_file_mgr,
            message_enqueued_callback=self._enqueued_some_message,
            local_sources_watcher=LocalSourcesWatcher(self._main_script_path),
            user_info=user_info,
        )

        LOGGER.debug(
            "Created new session for client %s. Session ID: %s", id(client), session.id
        )

        assert (
            session.id not in self._session_info_by_id
        ), f"session.id '{session.id}' registered multiple times!"

        self._session_info_by_id[session.id] = SessionInfo(client, session)
        self._set_state(State.ONE_OR_MORE_SESSIONS_CONNECTED)
        self._has_connection.notify_all()

        return session

    def _close_app_session(self, session_id: str) -> None:
        """Shut down and remove an AppSession.

        This function may be called multiple times for the same session,
        which is not an error. (Subsequent calls just no-op.)

        Parameters
        ----------
        session_id : str
            The AppSession's id string.
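
        Notes
        -----
        When the last session is removed, the server transitions to the
        NO_SESSIONS_CONNECTED state.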
""" if session_id in self._session_info_by_id: session_info = self._session_info_by_id[session_id] del self._session_info_by_id[session_id] session_info.session.shutdown() if len(self._session_info_by_id) == 0: self._set_state(State.NO_SESSIONS_CONNECTED) def _get_eventloop(self) -> asyncio.AbstractEventLoop: """Return the asyncio eventloop that the Server was started with. If the Server hasn't been started, this will raise an error. """ if self._eventloop is None: raise RuntimeError("Server hasn't started yet!") return self._eventloop def _set_tornado_log_levels() -> None: if not config.get_option("global.developmentMode"): # Hide logs unless they're super important. # Example of stuff we don't care about: 404 about .js.map files. logging.getLogger("tornado.access").setLevel(logging.ERROR) logging.getLogger("tornado.application").setLevel(logging.ERROR) logging.getLogger("tornado.general").setLevel(logging.ERROR)