194 lines
5.9 KiB
Python
194 lines
5.9 KiB
Python
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Server related utility functions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Callable, Final, Literal, cast
|
|
from urllib.parse import urljoin
|
|
|
|
from streamlit import config, net_util, url_util
|
|
from streamlit.runtime.secrets import secrets_singleton
|
|
from streamlit.type_util import is_version_less_than
|
|
|
|
if TYPE_CHECKING:
|
|
from tornado.web import RequestHandler
|
|
|
|
# Port reported as the browser address-bar port while
# ``global.developmentMode`` is enabled (internal development only).
DEVELOPMENT_PORT: Final = 3000

# Name of the authentication cookie.
AUTH_COOKIE_NAME: Final = "_streamlit_user"
|
|
|
|
|
|
def allowlisted_origins() -> set[str]:
    """Return the ``server.corsAllowedOrigins`` entries, whitespace-stripped.

    Entries that are identical after stripping collapse into one set element.
    """
    configured_origins = config.get_option("server.corsAllowedOrigins")
    return {entry.strip() for entry in configured_origins}
|
|
|
|
|
|
def is_tornado_version_less_than(v: str) -> bool:
    """Check whether the installed Tornado version is older than ``v``.

    Parameters
    ----------
    v : str
        Version string to compare against, e.g. "0.25.0"

    Returns
    -------
    bool
        True if ``tornado.version`` is less than ``v``.

    Raises
    ------
    InvalidVersion
        If the version strings are not valid.
    """
    # Tornado is imported lazily, only when a comparison is actually requested.
    import tornado

    installed_version: str = tornado.version
    return is_version_less_than(installed_version, v)
|
|
|
|
|
|
def is_url_from_allowed_origins(url: str) -> bool:
    """Return True if URL is from allowed origins (for CORS purpose).

    The hostname of ``url`` is compared for exact equality against these
    allowed hostnames, in this order:

    1. ``localhost``, ``0.0.0.0`` and ``127.0.0.1``.
    2. The hostname of ``browser.serverAddress``, if it was set manually.
    3. The hostnames of the origins listed in ``server.corsAllowedOrigins``.
    4. The internal and external IP addresses of the machine where this
       function was called from.

    Steps 1-3 only read configuration. Step 4 may open sockets or make HTTP
    requests, so it runs last and only when nothing earlier matched.

    If `server.enableCORS` is False, this allows all origins.

    Parameters
    ----------
    url : str
        The URL (typically a request's ``Origin``) to check.

    Returns
    -------
    bool
        True if the URL's hostname matches one of the allowed hostnames (or
        CORS is disabled), False otherwise.
    """
    if not config.get_option("server.enableCORS"):
        # Allow everything when CORS is disabled.
        return True

    hostname = url_util.get_hostname(url)

    # Each candidate is either a literal hostname (possibly None) or a zero-arg
    # callable that is only invoked if every earlier candidate failed to match.
    allowed_domains: list[str | None | Callable[[], str | None]] = [
        # Check localhost first.
        "localhost",
        "0.0.0.0",  # noqa: S104
        "127.0.0.1",
        # Try to avoid making unnecessary HTTP requests by checking if the user
        # manually specified a server address.
        _get_server_address_if_manually_set,
        # The user-configured allowlist is cheap to evaluate (no network), so
        # check it before any lookup that may hit the network.
        *(url_util.get_hostname(origin) for origin in allowlisted_origins()),
        # Then try the options that depend on HTTP requests or opening sockets.
        net_util.get_internal_ip,
        net_util.get_external_ip,
    ]

    for allowed_domain in allowed_domains:
        allowed_domain_str = (
            allowed_domain() if callable(allowed_domain) else allowed_domain
        )

        # A candidate that could not be resolved never matches.
        if allowed_domain_str is None:
            continue

        if hostname == allowed_domain_str:
            return True

    return False
|
|
|
|
|
|
def get_cookie_secret() -> str:
    """Get the cookie secret.

    The value of the ``server.cookieSecret`` config option is used unless a
    secrets TOML file exists and has a non-empty ``auth`` section containing a
    ``cookie_secret`` key. In that case, the value from ``auth`` wins.

    NOTE(review): this function does not generate a secret itself. Any
    random default presumably comes from the ``server.cookieSecret`` option's
    default value; confirm in the config module.
    """
    configured_secret: str = config.get_option("server.cookieSecret")

    if not secrets_singleton.load_if_toml_exists():
        return configured_secret

    auth_section = secrets_singleton.get("auth")
    if not auth_section:
        return configured_secret

    return auth_section.get("cookie_secret", configured_secret)
|
|
|
|
|
|
def is_xsrf_enabled() -> bool:
    """Return whether XSRF protection should be active.

    Protection is on when the ``server.enableXsrfProtection`` option is truthy.
    If the option is off, it is still forced on when a secrets TOML file exists
    and defines an ``auth`` section.
    """
    xsrf_option = config.get_option("server.enableXsrfProtection")
    if xsrf_option:
        return cast("bool", xsrf_option)

    # The secrets file is only consulted when the option itself is off.
    if not secrets_singleton.load_if_toml_exists():
        return cast("bool", xsrf_option)

    return secrets_singleton.get("auth", None) is not None
|
|
|
|
|
|
def _get_server_address_if_manually_set() -> str | None:
    """Return the hostname of ``browser.serverAddress`` if the user set it.

    Returns None when the option was not set manually.
    """
    if not config.is_manually_set("browser.serverAddress"):
        return None
    server_address = config.get_option("browser.serverAddress")
    return url_util.get_hostname(server_address)
|
|
|
|
|
|
def make_url_path_regex(
    *path: str,
    trailing_slash: Literal["optional", "required", "prohibited"] = "optional",
) -> str:
    """Get a regex of the form ^/foo/bar/baz/?$ for a path (foo, bar, baz).

    Each component has leading/trailing slashes stripped, and the remaining
    segments are joined with a single "/". Falsy components, and components
    made only of slashes, are dropped. So ``("", "/", "foo")`` yields the same
    regex as ``("foo",)``.

    Components are inserted verbatim and are NOT regex-escaped, so callers may
    pass regex fragments such as capture groups (e.g. ``"(.*)"``).

    Parameters
    ----------
    *path : str
        Path components.
    trailing_slash : {"optional", "required", "prohibited"}
        Whether the matched path may, must, or must not end with "/". Any other
        value is treated like "prohibited".

    Returns
    -------
    str
        An anchored regex string. When no non-empty components remain, the
        regex matches only the root path "/".
    """
    # Strip before filtering out empties. Otherwise a slash-only component
    # (e.g. "/") would survive the filter as an empty segment and produce "//".
    segments = [s for s in (x.strip("/") for x in path if x) if s]

    if not segments:
        # The root path is just "/". Appending the trailing-slash suffix would
        # give "^//?$" or "^//$", and the latter never matches "/".
        return r"^/$"

    if trailing_slash == "optional":
        suffix = "/?"
    elif trailing_slash == "required":
        suffix = "/"
    else:
        suffix = ""

    return "^/" + "/".join(segments) + suffix + "$"
|
|
|
|
|
|
def get_url(host_ip: str) -> str:
    """Get the URL for any app served at the given host_ip.

    Parameters
    ----------
    host_ip : str
        The IP address of the machine that is running the Streamlit Server.
        Surrounding slashes are stripped. An IPv6 literal (one containing
        ":") that is not already bracketed is wrapped in "[...]".

    Returns
    -------
    str
        The URL, of the form ``{protocol}://{host}:{port}{base_path}``. The
        protocol is "https" when ``server.sslCertFile`` is set, else "http".
    """
    protocol = "https" if config.get_option("server.sslCertFile") else "http"

    port = _get_browser_address_bar_port()
    base_path = config.get_option("server.baseUrlPath").strip("/")

    if base_path:
        base_path = "/" + base_path

    host_ip = host_ip.strip("/")

    # IPv6 literals contain ":", which would be ambiguous with the port
    # separator. RFC 3986 requires them to be enclosed in square brackets.
    if ":" in host_ip and not host_ip.startswith("["):
        host_ip = f"[{host_ip}]"

    return f"{protocol}://{host_ip}:{port}{base_path}"
|
|
|
|
|
|
def _get_browser_address_bar_port() -> int:
    """Get the port that will be shown in the browser's address bar.

    That is, this is the port where static assets will be served from. In dev,
    this is different from the port used to connect to the server-browser
    websocket.

    Returns ``DEVELOPMENT_PORT`` when ``global.developmentMode`` is on.
    Otherwise returns the ``browser.serverPort`` option converted to ``int``.
    """
    in_development_mode = config.get_option("global.developmentMode")
    if in_development_mode:
        return DEVELOPMENT_PORT
    configured_port = config.get_option("browser.serverPort")
    return int(configured_port)
|
|
|
|
|
|
def emit_endpoint_deprecation_notice(handler: RequestHandler, new_path: str) -> None:
    """Flag the current response as coming from a deprecated HTTP endpoint.

    Two headers are set on ``handler``:

    - ``Deprecation``, with the value ``True``.
    - ``Link``, pointing at ``new_path`` resolved against the current request's
      protocol and host, with ``rel="alternate"``.
    """
    handler.set_header("Deprecation", True)

    request = handler.request
    request_origin = f"{request.protocol}://{request.host}"
    replacement_url = urljoin(request_origin, new_path)
    handler.set_header("Link", f'<{replacement_url}>; rel="alternate"')
|