# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# mypy: disable-error-code="no-untyped-call"

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, cast

from authlib.integrations.base_client import (
    BaseApp,
    BaseOAuth,
    OAuth2Mixin,
    OAuthError,
    OpenIDMixin,
)
from authlib.integrations.requests_client import (
    OAuth2Session,
)

from streamlit.web.server.authlib_tornado_integration import TornadoIntegration

if TYPE_CHECKING:
    import tornado.web

    from streamlit.auth_util import AuthCache


class TornadoOAuth2App(OAuth2Mixin, OpenIDMixin, BaseApp):
    """OAuth 2 / OpenID Connect client app driven by Tornado request handlers."""

    # HTTP client class used for the OAuth 2 requests.
    client_cls = OAuth2Session

    def load_server_metadata(self) -> dict[str, Any]:
        """Load the server metadata, enforcing the S256 code challenge method.

        When the metadata lists ``S256`` under
        ``code_challenge_methods_supported``, ``code_challenge_method`` is set
        to ``"S256"`` in ``client_kwargs``.

        :return: The server metadata dict.
        """
        metadata = cast("dict[str, Any]", super().load_server_metadata())
        supported_methods = metadata.get("code_challenge_methods_supported", [])
        if "S256" in supported_methods:
            self.client_kwargs["code_challenge_method"] = "S256"
        return metadata

    def authorize_redirect(
        self,
        request_handler: tornado.web.RequestHandler,
        redirect_uri: Any = None,
        **kwargs: Any,
    ) -> None:
        """Redirect the client to the authorization endpoint.

        The authorization data is stored (keyed by its state) before the
        redirect is issued, so that the callback can look it up again.

        :param request_handler: HTTP request instance from Tornado.
        :param redirect_uri: Callback or redirect URI for authorization.
        :param kwargs: Extra parameters to include.
        """
        authorization = self.create_authorization_url(redirect_uri, **kwargs)
        self._save_authorize_data(redirect_uri=redirect_uri, **authorization)
        request_handler.redirect(authorization["url"], status=302)

    def authorize_access_token(
        self, request_handler: tornado.web.RequestHandler, **kwargs: Any
    ) -> dict[str, Any]:
        """Exchange the authorization callback arguments for a token.

        :param request_handler: HTTP request instance from Tornado.
        :param kwargs: Extra parameters forwarded to ``fetch_access_token``;
            ``claims_options`` is removed and used for ID token parsing.
        :return: A token dict. It carries an extra ``"userinfo"`` entry when
            the token has an ``"id_token"`` and the state data has a ``"nonce"``.
        :raises OAuthError: If the request carries an ``error`` argument.
        """
        error = request_handler.get_argument("error", None)
        if error:
            raise OAuthError(
                error=error,
                description=request_handler.get_argument("error_description", None),
            )

        # Both arguments are read without a default value.
        code = request_handler.get_argument("code")
        state = request_handler.get_argument("state")
        params = {"code": code, "state": state}

        claims_options = kwargs.pop("claims_options", None)

        # There is no session in Tornado, hence ``None`` as the session argument.
        state_data = self.framework.get_state_data(None, state)
        self.framework.clear_state_data(None, state)
        params = self._format_state_params(state_data, params)  # type: ignore[attr-defined]

        token = self.fetch_access_token(**params, **kwargs)

        if "id_token" in token and "nonce" in state_data:
            userinfo = self.parse_id_token(
                token,
                nonce=state_data["nonce"],
                claims_options=claims_options,
            )
            token = {**token, "userinfo": userinfo}

        return cast("dict[str, Any]", token)

    def _save_authorize_data(self, **kwargs: Any) -> None:
        """Store the authorization data under its ``state`` value.

        Authlib underlying uses the concept of "session" to store state data.
        In Tornado, we don't have a session, so we use the framework's cache
        option.

        :param kwargs: Authorization data; ``state`` is removed and used as the
            key, the remaining entries are stored as the state data.
        :raises RuntimeError: If ``state`` is missing or falsy.
        """
        state = kwargs.pop("state", None)
        if not state:
            raise RuntimeError("Missing state value")
        # No session object exists here, so ``None`` is passed in its place.
        self.framework.set_state_data(None, state, kwargs)


class TornadoOAuth(BaseOAuth):
    """OAuth registry wired to the Tornado client app and integration."""

    # Client class used for OAuth 2 providers.
    oauth2_client_cls = TornadoOAuth2App
    # Framework integration class.
    framework_integration_cls = TornadoIntegration

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        cache: AuthCache | None = None,
        fetch_token: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        update_token: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
        """Initialize the registry.

        :param config: Configuration dict, stored as ``self.config``.
        :param cache: Cache forwarded to the base class.
        :param fetch_token: Callback forwarded to the base class.
        :param update_token: Callback forwarded to the base class.
        """
        super().__init__(
            cache=cache,
            fetch_token=fetch_token,
            update_token=update_token,
        )
        self.config = config