# Copyright 2018-2022 Streamlit Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import errno
import logging
import os
import socket
import sys
import time
import traceback
from enum import Enum
from typing import (
    Any,
    Dict,
    Optional,
    Tuple,
    Callable,
    List,
)
import click
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.httpserver import HTTPServer
from streamlit import config
from streamlit import file_util
from streamlit import source_util
from streamlit import util
from streamlit.runtime.caching import (
    get_memo_stats_provider,
    get_singleton_stats_provider,
)
from streamlit.components.v1.components import ComponentRegistry
from streamlit.config_option import ConfigOption
from streamlit.runtime.legacy_caching.caching import _mem_caches
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.forward_msg_cache import (
    ForwardMsgCache,
    create_reference_msg,
    populate_hash_if_needed,
)
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
from streamlit.runtime.session_data import SessionData
from streamlit.runtime.state import (
    SCRIPT_RUN_WITHOUT_ERRORS_KEY,
    SessionStateStatProvider,
)
from streamlit.runtime.stats import StatsManager
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.watcher import LocalSourcesWatcher
from streamlit.web.server.routes import (
    AddSlashHandler,
    AssetsFileHandler,
    HealthHandler,
    MediaFileHandler,
    MessageCacheHandler,
    StaticFileHandler,
)
from streamlit.web.server.server_util import (
    get_max_message_size_bytes,
    is_cacheable_msg,
    make_url_path_regex,
)
from streamlit.web.server.upload_file_request_handler import (
    UploadFileRequestHandler,
    UPLOAD_FILE_ROUTE,
)
from .browser_websocket_handler import BrowserWebSocketHandler
from .component_request_handler import ComponentRequestHandler
from .session_client import SessionClient, SessionClientDisconnectedError
from .stats_request_handler import StatsRequestHandler

LOGGER = get_logger(__name__)

TORNADO_SETTINGS = {
    # Gzip HTTP responses.
    "compress_response": True,
    # Ping every 1s to keep WS alive.
    # 2021.06.22: this value was previously 20s, and was causing
    # connection instability for a small number of users. This smaller
    # ping_interval fixes that instability.
    # https://github.com/streamlit/streamlit/issues/3196
    "websocket_ping_interval": 1,
    # If we don't get a ping response within 30s, the connection
    # is timed out.
    "websocket_ping_timeout": 30,
}

# When server.port is not available, the server searches for the next
# available port, trying up to MAX_PORT_SEARCH_RETRIES ports.
MAX_PORT_SEARCH_RETRIES = 100

# When server.address starts with this prefix, the server will bind
# to a unix socket.
UNIX_SOCKET_PREFIX = "unix://"

# Wait up to 60s for a script run result; if no result is available, give up.
SCRIPT_RUN_CHECK_TIMEOUT = 60


class SessionInfo:
    """Type stored in our _session_info_by_id dict.

    For each AppSession, the server tracks that session's
    script_run_count. This is used to track the age of messages in
    the ForwardMsgCache.
    """

    def __init__(self, client: SessionClient, session: AppSession):
        """Initialize a SessionInfo instance.

        Parameters
        ----------
        session : AppSession
            The AppSession object.
        client : SessionClient
            The concrete SessionClient for this session.
        """
        self.session = session
        self.client = client
        self.script_run_count = 0

    def __repr__(self) -> str:
        return util.repr_(self)


class State(Enum):
    INITIAL = "INITIAL"
    WAITING_FOR_FIRST_SESSION = "WAITING_FOR_FIRST_SESSION"
    ONE_OR_MORE_SESSIONS_CONNECTED = "ONE_OR_MORE_SESSIONS_CONNECTED"
    NO_SESSIONS_CONNECTED = "NO_SESSIONS_CONNECTED"
    STOPPING = "STOPPING"
    STOPPED = "STOPPED"
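

# Lifecycle implied by the transitions in _loop_coroutine, _create_app_session,
# _close_app_session, and stop():
#
#   INITIAL -> WAITING_FOR_FIRST_SESSION -> ONE_OR_MORE_SESSIONS_CONNECTED
#   ONE_OR_MORE_SESSIONS_CONNECTED <-> NO_SESSIONS_CONNECTED
#       (as browser sessions connect and disconnect)
#   stop() -> STOPPING, and the run loop finishes in STOPPED.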


class RetriesExceeded(Exception):
    pass


def server_port_is_manually_set() -> bool:
    return config.is_manually_set("server.port")


def server_address_is_unix_socket() -> bool:
    address = config.get_option("server.address")
    return address is not None and address.startswith(UNIX_SOCKET_PREFIX)
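

# Illustrative example (not a shipped default): setting
# `server.address = "unix:///tmp/streamlit.sock"` makes
# server_address_is_unix_socket() return True, so start_listening() binds a
# unix domain socket at /tmp/streamlit.sock instead of a TCP port.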


def start_listening(app: tornado.web.Application) -> None:
    """Makes the server start listening at the configured port.

    In case the port is already taken it tries listening to the next available
    port. It will error after MAX_PORT_SEARCH_RETRIES attempts.
    """
    http_server = HTTPServer(
        app, max_buffer_size=config.get_option("server.maxUploadSize") * 1024 * 1024
    )

    if server_address_is_unix_socket():
        start_listening_unix_socket(http_server)
    else:
        start_listening_tcp_socket(http_server)


def start_listening_unix_socket(http_server: HTTPServer) -> None:
    address = config.get_option("server.address")
    file_name = os.path.expanduser(address[len(UNIX_SOCKET_PREFIX) :])

    unix_socket = tornado.netutil.bind_unix_socket(file_name)
    http_server.add_socket(unix_socket)


def start_listening_tcp_socket(http_server: HTTPServer) -> None:
    call_count = 0

    port = None
    while call_count < MAX_PORT_SEARCH_RETRIES:
        address = config.get_option("server.address")
        port = config.get_option("server.port")

        try:
            http_server.listen(port, address)
            break  # It worked! So let's break out of the loop.

        except (OSError, socket.error) as e:
            if e.errno == errno.EADDRINUSE:
                if server_port_is_manually_set():
                    LOGGER.error("Port %s is already in use", port)
                    sys.exit(1)
                else:
                    LOGGER.debug(
                        "Port %s already in use, trying to use the next one.", port
                    )
                    port += 1
                    # Skip port 3000 because it is used by the front end's
                    # development server.
                    if port == 3000:
                        port += 1

                    config.set_option(
                        "server.port", port, ConfigOption.STREAMLIT_DEFINITION
                    )
                    call_count += 1
            else:
                raise

    if call_count >= MAX_PORT_SEARCH_RETRIES:
        raise RetriesExceeded(
            f"Cannot start Streamlit server. Port {port} is already in use, and "
            f"Streamlit was unable to find a free port after {MAX_PORT_SEARCH_RETRIES} attempts.",
        )
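

# Example of the retry behavior above (assuming the default server.port of
# 8501 and no manually-set port): if 8501 is busy, the server logs a debug
# message and retries on 8502, 8503, ... (skipping 3000) until a port is free
# or MAX_PORT_SEARCH_RETRIES attempts have been made, after which
# RetriesExceeded is raised.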


class Server:
    def __init__(self, main_script_path: str, command_line: Optional[str]):
        """Create the server. It won't be started yet."""
        _set_tornado_log_levels()

        self._main_script_path = main_script_path
        self._command_line = command_line if command_line is not None else ""

        # Will be set when we start.
        self._eventloop: Optional[asyncio.AbstractEventLoop] = None

        # Mapping of AppSession.id -> SessionInfo.
        self._session_info_by_id: Dict[str, SessionInfo] = {}

        self._must_stop = tornado.locks.Event()
        self._state = State.INITIAL
        self._message_cache = ForwardMsgCache()
        self._uploaded_file_mgr = UploadedFileManager()
        self._uploaded_file_mgr.on_files_updated.connect(self.on_files_updated)
        self._session_data: Optional[SessionData] = None
        self._has_connection = tornado.locks.Condition()
        self._need_send_data = tornado.locks.Event()

        # StatsManager
        self._stats_mgr = StatsManager()
        self._stats_mgr.register_provider(get_memo_stats_provider())
        self._stats_mgr.register_provider(get_singleton_stats_provider())
        self._stats_mgr.register_provider(_mem_caches)
        self._stats_mgr.register_provider(self._message_cache)
        self._stats_mgr.register_provider(in_memory_file_manager)
        self._stats_mgr.register_provider(self._uploaded_file_mgr)
        self._stats_mgr.register_provider(
            SessionStateStatProvider(self._session_info_by_id)
        )

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def main_script_path(self) -> str:
        return self._main_script_path

    def on_files_updated(self, session_id: str) -> None:
        """Event handler for UploadedFileManager.on_files_updated.

        Ensures that uploaded files from stale sessions get deleted.
        """
        session_info = self._get_session_info(session_id)
        if session_info is None:
            # If an uploaded file doesn't belong to an existing session,
            # remove it so it doesn't stick around forever.
            self._uploaded_file_mgr.remove_session_files(session_id)

    def _get_session_info(self, session_id: str) -> Optional[SessionInfo]:
        """Return the SessionInfo with the given id, or None if no such
        session exists.
        """
        return self._session_info_by_id.get(session_id, None)

    def is_active_session(self, session_id: str) -> bool:
        """True if the session_id belongs to an active session."""
        return session_id in self._session_info_by_id

    async def start(self, on_started: Callable[["Server"], Any]) -> None:
        """Start the server.

        Parameters
        ----------
        on_started : callable
            A callback that will be called when the server's run-loop
            has started, and the server is ready to begin receiving clients.
        """
        if self._state != State.INITIAL:
            raise RuntimeError("Server has already been started")

        LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        port = config.get_option("server.port")
        LOGGER.debug("Server started on port %s", port)

        await self._loop_coroutine(on_started)

    def _create_app(self) -> tornado.web.Application:
        """Create our tornado web app."""
        base = config.get_option("server.baseUrlPath")

        routes: List[Any] = [
            (
                make_url_path_regex(base, "stream"),
                BrowserWebSocketHandler,
                dict(server=self),
            ),
            (
                make_url_path_regex(base, "healthz"),
                HealthHandler,
                dict(callback=lambda: self.is_ready_for_browser_connection),
            ),
            (
                make_url_path_regex(base, "message"),
                MessageCacheHandler,
                dict(cache=self._message_cache),
            ),
            (
                make_url_path_regex(base, "st-metrics"),
                StatsRequestHandler,
                dict(stats_manager=self._stats_mgr),
            ),
            (
                make_url_path_regex(
                    base,
                    UPLOAD_FILE_ROUTE,
                ),
                UploadFileRequestHandler,
                dict(
                    file_mgr=self._uploaded_file_mgr,
                    is_active_session=self.is_active_session,
                ),
            ),
            (
                make_url_path_regex(base, "assets/(.*)"),
                AssetsFileHandler,
                {"path": "%s/" % file_util.get_assets_dir()},
            ),
            (make_url_path_regex(base, "media/(.*)"), MediaFileHandler, {"path": ""}),
            (
                make_url_path_regex(base, "component/(.*)"),
                ComponentRequestHandler,
                dict(registry=ComponentRegistry.instance()),
            ),
        ]

        if config.get_option("server.scriptHealthCheckEnabled"):
            routes.extend(
                [
                    (
                        make_url_path_regex(base, "script-health-check"),
                        HealthHandler,
                        dict(callback=lambda: self.does_script_run_without_error()),
                    )
                ]
            )

        if config.get_option("global.developmentMode"):
            LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = file_util.get_static_dir()
            LOGGER.debug("Serving static content from %s", static_path)

            routes.extend(
                [
                    (
                        make_url_path_regex(base, "(.*)"),
                        StaticFileHandler,
                        {
                            "path": "%s/" % static_path,
                            "default_filename": "index.html",
                            "get_pages": lambda: set(
                                [
                                    page_info["page_name"]
                                    for page_info in source_util.get_pages(
                                        self.main_script_path
                                    ).values()
                                ]
                            ),
                        },
                    ),
                    (make_url_path_regex(base, trailing_slash=False), AddSlashHandler),
                ]
            )

        return tornado.web.Application(
            routes,
            cookie_secret=config.get_option("server.cookieSecret"),
            xsrf_cookies=config.get_option("server.enableXsrfProtection"),
            # Set the websocket message size. The default value is too low.
            websocket_max_message_size=get_max_message_size_bytes(),
            **TORNADO_SETTINGS,  # type: ignore[arg-type]
        )
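
    # Summary of the routes registered above, relative to server.baseUrlPath
    # (the authoritative list is the `routes` table in _create_app):
    #   /stream               -> BrowserWebSocketHandler (the app websocket)
    #   /healthz              -> HealthHandler
    #   /message              -> MessageCacheHandler (cached ForwardMsgs)
    #   /st-metrics           -> StatsRequestHandler
    #   <UPLOAD_FILE_ROUTE>   -> UploadFileRequestHandler
    #   /assets/*, /media/*   -> asset and media file handlers
    #   /component/*          -> ComponentRequestHandler (custom components)
    #   /script-health-check  -> HealthHandler (only if scriptHealthCheckEnabled)
    #   everything else       -> StaticFileHandler serving the built frontend
    #                            (skipped in global.developmentMode)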

    def _set_state(self, new_state: State) -> None:
        LOGGER.debug("Server state: %s -> %s" % (self._state, new_state))
        self._state = new_state

    @property
    async def is_ready_for_browser_connection(self) -> Tuple[bool, str]:
        if self._state not in (State.INITIAL, State.STOPPING, State.STOPPED):
            return True, "ok"

        return False, "unavailable"

    async def does_script_run_without_error(self) -> Tuple[bool, str]:
        """Load and execute the app's script to verify it runs without an error.

        Returns
        -------
        (True, "ok") if the script completes without error, or (False, err_msg)
        if the script raises an exception.
        """
        session = AppSession(
            event_loop=self._get_eventloop(),
            session_data=SessionData(self._main_script_path, self._command_line),
            uploaded_file_manager=self._uploaded_file_mgr,
            message_enqueued_callback=self._enqueued_some_message,
            local_sources_watcher=LocalSourcesWatcher(self._main_script_path),
            user_info={"email": "test@test.com"},
        )

        try:
            session.request_rerun(None)

            now = time.perf_counter()
            while (
                SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state
                and (time.perf_counter() - now) < SCRIPT_RUN_CHECK_TIMEOUT
            ):
                await asyncio.sleep(0.1)

            if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state:
                return False, "timeout"

            ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY]
            msg = "ok" if ok else "error"
            return ok, msg
        finally:
            session.shutdown()
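
    # Note: this check backs the "script-health-check" route registered in
    # _create_app when server.scriptHealthCheckEnabled is set. It spins up a
    # throwaway AppSession, reruns the script once, and reports (True, "ok"),
    # (False, "error"), or (False, "timeout") after SCRIPT_RUN_CHECK_TIMEOUT.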

    @property
    def browser_is_connected(self) -> bool:
        return self._state == State.ONE_OR_MORE_SESSIONS_CONNECTED

    @property
    def is_running_hello(self) -> bool:
        from streamlit.hello import Hello

        return self._main_script_path == Hello.__file__

    async def _loop_coroutine(
        self, on_started: Optional[Callable[["Server"], Any]] = None
    ) -> None:
        try:
            if self._state == State.INITIAL:
                self._set_state(State.WAITING_FOR_FIRST_SESSION)
            elif self._state == State.ONE_OR_MORE_SESSIONS_CONNECTED:
                pass
            else:
                raise RuntimeError(f"Bad server state at start: {self._state}")

            # Store the eventloop we're running on so that we can schedule
            # callbacks on it when necessary. (We can't just call
            # `asyncio.get_running_loop()` whenever we like, because we have
            # some functions, e.g. `stop`, that can be called from other
            # threads, and `asyncio.get_running_loop()` is thread-specific.)
            self._eventloop = asyncio.get_running_loop()

            if on_started is not None:
                on_started(self)

            while not self._must_stop.is_set():
                if self._state == State.WAITING_FOR_FIRST_SESSION:
                    await asyncio.wait(
                        [self._must_stop.wait(), self._has_connection.wait()],
                        return_when=asyncio.FIRST_COMPLETED,
                    )

                elif self._state == State.ONE_OR_MORE_SESSIONS_CONNECTED:
                    self._need_send_data.clear()

                    # Shallow-clone our sessions into a list, so we can iterate
                    # over it and not worry about whether it's being changed
                    # outside this coroutine.
                    session_infos = list(self._session_info_by_id.values())

                    for session_info in session_infos:
                        msg_list = session_info.session.flush_browser_queue()
                        for msg in msg_list:
                            try:
                                self._send_message(session_info, msg)
                            except SessionClientDisconnectedError:
                                self._close_app_session(session_info.session.id)
                            # Yield the event loop after each message send.
                            await asyncio.sleep(0)
                        # Yield between sessions.
                        await asyncio.sleep(0)
                    # Yield briefly between flush passes.
                    await asyncio.sleep(0.01)

                elif self._state == State.NO_SESSIONS_CONNECTED:
                    await asyncio.wait(
                        [self._must_stop.wait(), self._has_connection.wait()],
                        return_when=asyncio.FIRST_COMPLETED,
                    )

                else:
                    # Break out of the thread loop if we encounter any other state.
                    break

                await asyncio.wait(
                    [self._must_stop.wait(), self._need_send_data.wait()],
                    return_when=asyncio.FIRST_COMPLETED,
                )

            # Shut down all AppSessions
            for session_info in list(self._session_info_by_id.values()):
                session_info.session.shutdown()

            self._set_state(State.STOPPED)

        except Exception:
            # Can't just re-raise here because co-routines use Tornado
            # exceptions for control flow, which appears to swallow the reraised
            # exception.
            traceback.print_exc()
            LOGGER.info(
                """
Please report this bug at https://github.com/streamlit/streamlit/issues.
"""
            )

    def _send_message(self, session_info: SessionInfo, msg: ForwardMsg) -> None:
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with the websocket
        msg : ForwardMsg
            The message to send to the client
        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                msg, session_info.session, session_info.script_run_count
            ):
                # This session has probably cached this message. Send
                # a reference instead.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(
                msg, session_info.session, session_info.script_run_count
            )

        # If this was a `script_finished` message, we increment the
        # script_run_count for this session, and update the cache
        if (
            msg.WhichOneof("type") == "script_finished"
            and msg.script_finished == ForwardMsg.FINISHED_SUCCESSFULLY
        ):
            LOGGER.debug(
                "Script run finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.script_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.script_run_count
            )

        # Ship it off!
        session_info.client.write_forward_msg(msg_to_send)
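
    # A note on the caching protocol above: every cacheable ForwardMsg is
    # added to the ForwardMsgCache keyed by its hash. When the same message
    # would be sent again to a session that has probably already seen it,
    # only a small reference message (the hash) goes over the websocket, and
    # the client can retrieve the full payload from the "message" route
    # (MessageCacheHandler) if it doesn't have it locally. script_run_count
    # is how cache entries age out between script reruns.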

    def _enqueued_some_message(self) -> None:
        self._get_eventloop().call_soon_threadsafe(self._need_send_data.set)

    def stop(self) -> None:
        click.secho(" Stopping...", fg="blue")
        self._set_state(State.STOPPING)
        self._get_eventloop().call_soon_threadsafe(self._must_stop.set)

    def _create_app_session(
        self, client: SessionClient, user_info: Dict[str, Optional[str]]
    ) -> AppSession:
        """Register a connected browser with the server.

        Parameters
        ----------
        client : SessionClient
            The SessionClient for sending data to the session's client.
        user_info: Dict
            A dict that contains information about the current user. For now,
            it only contains the user's email address.

            {
                "email": "example@example.com"
            }

        Returns
        -------
        AppSession
            The newly-created AppSession for this browser connection.
        """
        session = AppSession(
            event_loop=self._get_eventloop(),
            session_data=SessionData(self._main_script_path, self._command_line),
            uploaded_file_manager=self._uploaded_file_mgr,
            message_enqueued_callback=self._enqueued_some_message,
            local_sources_watcher=LocalSourcesWatcher(self._main_script_path),
            user_info=user_info,
        )

        LOGGER.debug(
            "Created new session for client %s. Session ID: %s", id(client), session.id
        )

        assert (
            session.id not in self._session_info_by_id
        ), f"session.id '{session.id}' registered multiple times!"

        self._session_info_by_id[session.id] = SessionInfo(client, session)
        self._set_state(State.ONE_OR_MORE_SESSIONS_CONNECTED)
        self._has_connection.notify_all()

        return session

    def _close_app_session(self, session_id: str) -> None:
        """Shutdown and remove an AppSession.

        This function may be called multiple times for the same session,
        which is not an error. (Subsequent calls just no-op.)

        Parameters
        ----------
        session_id : str
            The AppSession's id string.
        """
        if session_id in self._session_info_by_id:
            session_info = self._session_info_by_id[session_id]
            del self._session_info_by_id[session_id]
            session_info.session.shutdown()

        if len(self._session_info_by_id) == 0:
            self._set_state(State.NO_SESSIONS_CONNECTED)

    def _get_eventloop(self) -> asyncio.AbstractEventLoop:
        """Return the asyncio eventloop that the Server was started with.

        If the Server hasn't been started, this will raise an error.
        """
        if self._eventloop is None:
            raise RuntimeError("Server hasn't started yet!")
        return self._eventloop


def _set_tornado_log_levels() -> None:
    if not config.get_option("global.developmentMode"):
        # Hide logs unless they're super important.
        # Example of stuff we don't care about: 404 about .js.map files.
        logging.getLogger("tornado.access").setLevel(logging.ERROR)
        logging.getLogger("tornado.application").setLevel(logging.ERROR)
        logging.getLogger("tornado.general").setLevel(logging.ERROR)