team-10/env/Lib/site-packages/streamlit/auth_util.py

220 lines
8.4 KiB
Python
Raw Permalink Normal View History

2025-08-02 07:34:44 +02:00
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, TypedDict, cast
from streamlit import config
from streamlit.errors import StreamlitAuthError
from streamlit.runtime.secrets import AttrDict, secrets_singleton
if TYPE_CHECKING:

    class ProviderTokenPayload(TypedDict, total=True):
        """Claims carried by the provider JWT built in ``encode_provider_token``.

        Declared for static type checkers only; the name does not exist at runtime.
        """

        # Expiry claim. NOTE(review): it is encoded from a datetime and presumably
        # comes back as integer epoch seconds after decoding — confirm.
        exp: int
        # Name of the auth provider the token was issued for.
        provider: str
class AuthCache:
    """In-memory key/value store for the information Authlib needs.

    Entries never expire. ``set`` accepts ``expires_in`` only so that its
    signature matches the one Authlib uses; the value is ignored.
    """

    def __init__(self) -> None:
        # Backing storage; ``get_dict`` hands out this very dict.
        self.cache: dict[str, Any] = {}

    def get(self, key: str) -> Any:
        """Return the value stored under ``key``, or ``None`` when it is absent."""
        try:
            return self.cache[key]
        except KeyError:
            return None

    # We follow the signature Authlib uses for ``set``; ``expires_in`` is
    # accepted for compatibility but not used.
    def set(self, key: str, value: Any, expires_in: int | None = None) -> None:  # noqa: ARG002
        """Store ``value`` under ``key``, replacing any previous value."""
        self.cache[key] = value

    def get_dict(self) -> dict[str, Any]:
        """Return the underlying storage dict itself (not a copy)."""
        return self.cache

    def delete(self, key: str) -> None:
        """Remove ``key`` if present; deleting a missing key is a no-op."""
        if key in self.cache:
            del self.cache[key]
def is_authlib_installed() -> bool:
"""Check if Authlib is installed."""
try:
import authlib
authlib_version = authlib.__version__
authlib_version_tuple = tuple(map(int, authlib_version.split(".")))
if authlib_version_tuple < (1, 3, 2):
return False
except (ImportError, ModuleNotFoundError):
return False
return True
def get_signing_secret() -> str:
    """Return the secret used for signing (cookies and provider tokens).

    A ``cookie_secret`` entry in the ``auth`` section of ``secrets.toml`` wins.
    Otherwise the ``server.cookieSecret`` config option is returned.
    """
    fallback: str = config.get_option("server.cookieSecret")
    if not secrets_singleton.load_if_toml_exists():
        return fallback

    auth_section = secrets_singleton.get("auth")
    if not auth_section:
        return fallback
    return auth_section.get("cookie_secret", fallback)
def get_secrets_auth_section() -> AttrDict:
    """Get the 'auth' section of the secrets.toml.

    Returns:
        The ``auth`` section, or an empty ``AttrDict`` when no ``secrets.toml``
        is found or it has no ``auth`` section. Previously the second case
        returned ``None``, contradicting the declared return type.
    """
    if not secrets_singleton.load_if_toml_exists():
        return AttrDict({})

    auth_section = secrets_singleton.get("auth")
    if auth_section is None:
        return AttrDict({})
    return cast("AttrDict", auth_section)
def encode_provider_token(provider: str) -> str:
    """Build a short-lived, HS256-signed JWT naming the auth ``provider``.

    The token carries ``provider`` and an ``exp`` claim two minutes from now.
    It is signed with ``get_signing_secret()``.

    Raises:
        StreamlitAuthError: If Authlib cannot be imported.
    """
    try:
        from authlib.jose import jwt
    except ImportError:
        raise StreamlitAuthError(
            "To use authentication features, you need to install Authlib>=1.3.2, "
            "e.g. via `pip install Authlib`."
        ) from None

    expires_at = datetime.now(timezone.utc) + timedelta(minutes=2)
    claims = {"provider": provider, "exp": expires_at}
    signed: bytes = jwt.encode({"alg": "HS256"}, claims, get_signing_secret())
    # jwt.encode yields bytes; callers get a URL-safe ``str`` instead.
    return signed.decode("latin-1")
def decode_provider_token(provider_token: str) -> ProviderTokenPayload:
    """Verify a provider JWT and return its claims.

    The signature is checked against ``get_signing_secret()``. Both ``exp`` and
    ``provider`` must be present, and the token must not be expired.

    Raises:
        StreamlitAuthError: If Authlib is unavailable, or if Authlib reports a
            ``JoseError`` while decoding or validating the token.
    """
    try:
        from authlib.jose import JoseError, JWTClaims, jwt
    except ImportError:
        raise StreamlitAuthError(
            "To use authentication features, you need to install Authlib>=1.3.2, "
            "e.g. via `pip install Authlib`."
        ) from None

    # The token is short-lived (2 minutes), so both claims are mandatory and
    # validate() enforces that it has not expired yet.
    required_claims = {
        "exp": {"essential": True},
        "provider": {"essential": True},
    }
    try:
        claims: JWTClaims = jwt.decode(
            provider_token, get_signing_secret(), claims_options=required_claims
        )
        claims.validate()
    except JoseError as err:
        raise StreamlitAuthError(f"Error decoding provider token: {err}") from None
    return cast("ProviderTokenPayload", claims)
def generate_default_provider_section(auth_section: AttrDict) -> dict[str, Any]:
    """Build the "default" provider section from top-level ``auth`` keys.

    Copies ``client_id``, ``client_secret`` and ``server_metadata_url`` when
    their values are truthy. Copies ``client_kwargs`` (converted through its
    ``to_dict()``) when truthy. Keys with falsy or missing values are omitted.
    """
    section: dict[str, Any] = {}
    for name in ("client_id", "client_secret", "server_metadata_url"):
        value = auth_section.get(name)
        if value:
            section[name] = value

    client_kwargs = auth_section.get("client_kwargs")
    if client_kwargs:
        section["client_kwargs"] = cast("AttrDict", client_kwargs).to_dict()
    return section
def validate_auth_credentials(provider: str) -> None:
    """Validate the general auth credentials and those for the given provider.

    Checks run in this order and the first failure raises:
    secrets file present, ``auth`` section present, ``redirect_uri`` and
    ``cookie_secret`` present, no underscore in ``provider``, the provider
    section exists and is a mapping, and it has ``client_id``,
    ``client_secret`` and ``server_metadata_url``. For ``provider == "default"``
    with no ``default`` sub-section, the section is built from top-level
    ``auth`` keys via ``generate_default_provider_section``.

    Args:
        provider: Name of the provider sub-section of ``auth``.

    Raises:
        StreamlitAuthError: Describing the first failed check.
    """
    no_credentials_msg = (
        "To use authentication features you need to configure credentials for at\n"
        "least one authentication provider in `.streamlit/secrets.toml`."
    )
    if not secrets_singleton.load_if_toml_exists():
        raise StreamlitAuthError(no_credentials_msg)

    auth_section = secrets_singleton.get("auth")
    if auth_section is None:
        raise StreamlitAuthError(no_credentials_msg)

    for general_key in ("redirect_uri", "cookie_secret"):
        if general_key not in auth_section:
            raise StreamlitAuthError(
                "Authentication credentials in `.streamlit/secrets.toml` are missing the\n"
                f'"{general_key}" key. Please check your configuration.'
            )

    provider_section = auth_section.get(provider)

    # TODO(kajarenc): Revisit this check later when investigating the ability
    # TODO(kajarenc): to add "_" to the provider name.
    if "_" in provider:
        raise StreamlitAuthError(
            f'Auth provider name "{provider}" contains an underscore. '
            f"Please use a provider name without underscores."
        )

    if provider_section is None and provider == "default":
        # generate_default_provider_section always returns a dict (possibly
        # empty), so for the default provider any missing credentials are
        # reported by the required-keys check below. The former
        # "missing for the default authentication provider" branch was
        # therefore unreachable and has been removed.
        provider_section = generate_default_provider_section(auth_section)

    if provider_section is None:
        raise StreamlitAuthError(
            f"Authentication credentials in `.streamlit/secrets.toml` are missing for "
            f'the authentication provider "{provider}". Please check your '
            f"configuration."
        )

    if not isinstance(provider_section, Mapping):
        raise StreamlitAuthError(
            f"Authentication credentials in `.streamlit/secrets.toml` for the "
            f'authentication provider "{provider}" must be valid TOML. Please check '
            f"your configuration."
        )

    missing_keys = [
        key
        for key in ("client_id", "client_secret", "server_metadata_url")
        if key not in provider_section
    ]
    if missing_keys:
        provider_label = (
            "default authentication provider"
            if provider == "default"
            else f'authentication provider "{provider}"'
        )
        raise StreamlitAuthError(
            "Authentication credentials in `.streamlit/secrets.toml` for the "
            f"{provider_label} are missing the following keys: "
            f"{missing_keys}. Please check your configuration."
        )