461 lines
18 KiB
Python
461 lines
18 KiB
Python
![]() |
import datetime
|
||
|
import hashlib
|
||
|
import logging
|
||
|
import os
|
||
|
import time
|
||
|
import urllib.parse
|
||
|
import warnings
|
||
|
from dataclasses import dataclass
|
||
|
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||
|
|
||
|
from . import constants
|
||
|
from .hf_api import whoami
|
||
|
from .utils import experimental, get_token
|
||
|
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
import fastapi
|
||
|
|
||
|
|
||
|
@dataclass
|
||
|
class OAuthOrgInfo:
|
||
|
"""
|
||
|
Information about an organization linked to a user logged in with OAuth.
|
||
|
|
||
|
Attributes:
|
||
|
sub (`str`):
|
||
|
Unique identifier for the org. OpenID Connect field.
|
||
|
name (`str`):
|
||
|
The org's full name. OpenID Connect field.
|
||
|
preferred_username (`str`):
|
||
|
The org's username. OpenID Connect field.
|
||
|
picture (`str`):
|
||
|
The org's profile picture URL. OpenID Connect field.
|
||
|
is_enterprise (`bool`):
|
||
|
Whether the org is an enterprise org. Hugging Face field.
|
||
|
can_pay (`Optional[bool]`, *optional*):
|
||
|
Whether the org has a payment method set up. Hugging Face field.
|
||
|
role_in_org (`Optional[str]`, *optional*):
|
||
|
The user's role in the org. Hugging Face field.
|
||
|
security_restrictions (`Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]]`, *optional*):
|
||
|
Array of security restrictions that the user hasn't completed for this org. Possible values: "ip", "token-policy", "mfa", "sso". Hugging Face field.
|
||
|
"""
|
||
|
|
||
|
sub: str
|
||
|
name: str
|
||
|
preferred_username: str
|
||
|
picture: str
|
||
|
is_enterprise: bool
|
||
|
can_pay: Optional[bool] = None
|
||
|
role_in_org: Optional[str] = None
|
||
|
security_restrictions: Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]] = None
|
||
|
|
||
|
|
||
|
@dataclass
|
||
|
class OAuthUserInfo:
|
||
|
"""
|
||
|
Information about a user logged in with OAuth.
|
||
|
|
||
|
Attributes:
|
||
|
sub (`str`):
|
||
|
Unique identifier for the user, even in case of rename. OpenID Connect field.
|
||
|
name (`str`):
|
||
|
The user's full name. OpenID Connect field.
|
||
|
preferred_username (`str`):
|
||
|
The user's username. OpenID Connect field.
|
||
|
email_verified (`Optional[bool]`, *optional*):
|
||
|
Indicates if the user's email is verified. OpenID Connect field.
|
||
|
email (`Optional[str]`, *optional*):
|
||
|
The user's email address. OpenID Connect field.
|
||
|
picture (`str`):
|
||
|
The user's profile picture URL. OpenID Connect field.
|
||
|
profile (`str`):
|
||
|
The user's profile URL. OpenID Connect field.
|
||
|
website (`Optional[str]`, *optional*):
|
||
|
The user's website URL. OpenID Connect field.
|
||
|
is_pro (`bool`):
|
||
|
Whether the user is a pro user. Hugging Face field.
|
||
|
can_pay (`Optional[bool]`, *optional*):
|
||
|
Whether the user has a payment method set up. Hugging Face field.
|
||
|
orgs (`Optional[List[OrgInfo]]`, *optional*):
|
||
|
List of organizations the user is part of. Hugging Face field.
|
||
|
"""
|
||
|
|
||
|
sub: str
|
||
|
name: str
|
||
|
preferred_username: str
|
||
|
email_verified: Optional[bool]
|
||
|
email: Optional[str]
|
||
|
picture: str
|
||
|
profile: str
|
||
|
website: Optional[str]
|
||
|
is_pro: bool
|
||
|
can_pay: Optional[bool]
|
||
|
orgs: Optional[List[OAuthOrgInfo]]
|
||
|
|
||
|
|
||
|
@dataclass
|
||
|
class OAuthInfo:
|
||
|
"""
|
||
|
Information about the OAuth login.
|
||
|
|
||
|
Attributes:
|
||
|
access_token (`str`):
|
||
|
The access token.
|
||
|
access_token_expires_at (`datetime.datetime`):
|
||
|
The expiration date of the access token.
|
||
|
user_info ([`OAuthUserInfo`]):
|
||
|
The user information.
|
||
|
state (`str`, *optional*):
|
||
|
State passed to the OAuth provider in the original request to the OAuth provider.
|
||
|
scope (`str`):
|
||
|
Granted scope.
|
||
|
"""
|
||
|
|
||
|
access_token: str
|
||
|
access_token_expires_at: datetime.datetime
|
||
|
user_info: OAuthUserInfo
|
||
|
state: Optional[str]
|
||
|
scope: str
|
||
|
|
||
|
|
||
|
@experimental
|
||
|
def attach_huggingface_oauth(app: "fastapi.FastAPI", route_prefix: str = "/"):
|
||
|
"""
|
||
|
Add OAuth endpoints to a FastAPI app to enable OAuth login with Hugging Face.
|
||
|
|
||
|
How to use:
|
||
|
- Call this method on your FastAPI app to add the OAuth endpoints.
|
||
|
- Inside your route handlers, call `parse_huggingface_oauth(request)` to retrieve the OAuth info.
|
||
|
- If user is logged in, an [`OAuthInfo`] object is returned with the user's info. If not, `None` is returned.
|
||
|
- In your app, make sure to add links to `/oauth/huggingface/login` and `/oauth/huggingface/logout` for the user to log in and out.
|
||
|
|
||
|
Example:
|
||
|
```py
|
||
|
from huggingface_hub import attach_huggingface_oauth, parse_huggingface_oauth
|
||
|
|
||
|
# Create a FastAPI app
|
||
|
app = FastAPI()
|
||
|
|
||
|
# Add OAuth endpoints to the FastAPI app
|
||
|
attach_huggingface_oauth(app)
|
||
|
|
||
|
# Add a route that greets the user if they are logged in
|
||
|
@app.get("/")
|
||
|
def greet_json(request: Request):
|
||
|
# Retrieve the OAuth info from the request
|
||
|
oauth_info = parse_huggingface_oauth(request) # e.g. OAuthInfo dataclass
|
||
|
if oauth_info is None:
|
||
|
return {"msg": "Not logged in!"}
|
||
|
return {"msg": f"Hello, {oauth_info.user_info.preferred_username}!"}
|
||
|
```
|
||
|
"""
|
||
|
# TODO: handle generic case (handling OAuth in a non-Space environment with custom dev values) (low priority)
|
||
|
|
||
|
# Add SessionMiddleware to the FastAPI app to store the OAuth info in the session.
|
||
|
# Session Middleware requires a secret key to sign the cookies. Let's use a hash
|
||
|
# of the OAuth secret key to make it unique to the Space + updated in case OAuth
|
||
|
# config gets updated. When ran locally, we use an empty string as a secret key.
|
||
|
try:
|
||
|
from starlette.middleware.sessions import SessionMiddleware
|
||
|
except ImportError as e:
|
||
|
raise ImportError(
|
||
|
"Cannot initialize OAuth to due a missing library. Please run `pip install huggingface_hub[oauth]` or add "
|
||
|
"`huggingface_hub[oauth]` to your requirements.txt file in order to install the required dependencies."
|
||
|
) from e
|
||
|
session_secret = (constants.OAUTH_CLIENT_SECRET or "") + "-v1"
|
||
|
app.add_middleware(
|
||
|
SessionMiddleware, # type: ignore[arg-type]
|
||
|
secret_key=hashlib.sha256(session_secret.encode()).hexdigest(),
|
||
|
same_site="none",
|
||
|
https_only=True,
|
||
|
) # type: ignore
|
||
|
|
||
|
# Add OAuth endpoints to the FastAPI app:
|
||
|
# - {route_prefix}/oauth/huggingface/login
|
||
|
# - {route_prefix}/oauth/huggingface/callback
|
||
|
# - {route_prefix}/oauth/huggingface/logout
|
||
|
# If the app is running in a Space, OAuth is enabled normally.
|
||
|
# Otherwise, we mock the endpoints to make the user log in with a fake user profile - without any calls to hf.co.
|
||
|
route_prefix = route_prefix.strip("/")
|
||
|
if os.getenv("SPACE_ID") is not None:
|
||
|
logger.info("OAuth is enabled in the Space. Adding OAuth routes.")
|
||
|
_add_oauth_routes(app, route_prefix=route_prefix)
|
||
|
else:
|
||
|
logger.info("App is not running in a Space. Adding mocked OAuth routes.")
|
||
|
_add_mocked_oauth_routes(app, route_prefix=route_prefix)
|
||
|
|
||
|
|
||
|
def parse_huggingface_oauth(request: "fastapi.Request") -> Optional[OAuthInfo]:
|
||
|
"""
|
||
|
Returns the information from a logged in user as a [`OAuthInfo`] object.
|
||
|
|
||
|
For flexibility and future-proofing, this method is very lax in its parsing and does not raise errors.
|
||
|
Missing fields are set to `None` without a warning.
|
||
|
|
||
|
Return `None`, if the user is not logged in (no info in session cookie).
|
||
|
|
||
|
See [`attach_huggingface_oauth`] for an example on how to use this method.
|
||
|
"""
|
||
|
if "oauth_info" not in request.session:
|
||
|
logger.debug("No OAuth info in session.")
|
||
|
return None
|
||
|
|
||
|
logger.debug("Parsing OAuth info from session.")
|
||
|
oauth_data = request.session["oauth_info"]
|
||
|
user_data = oauth_data.get("userinfo", {})
|
||
|
orgs_data = user_data.get("orgs", [])
|
||
|
|
||
|
orgs = (
|
||
|
[
|
||
|
OAuthOrgInfo(
|
||
|
sub=org.get("sub"),
|
||
|
name=org.get("name"),
|
||
|
preferred_username=org.get("preferred_username"),
|
||
|
picture=org.get("picture"),
|
||
|
is_enterprise=org.get("isEnterprise"),
|
||
|
can_pay=org.get("canPay"),
|
||
|
role_in_org=org.get("roleInOrg"),
|
||
|
security_restrictions=org.get("securityRestrictions"),
|
||
|
)
|
||
|
for org in orgs_data
|
||
|
]
|
||
|
if orgs_data
|
||
|
else None
|
||
|
)
|
||
|
|
||
|
user_info = OAuthUserInfo(
|
||
|
sub=user_data.get("sub"),
|
||
|
name=user_data.get("name"),
|
||
|
preferred_username=user_data.get("preferred_username"),
|
||
|
email_verified=user_data.get("email_verified"),
|
||
|
email=user_data.get("email"),
|
||
|
picture=user_data.get("picture"),
|
||
|
profile=user_data.get("profile"),
|
||
|
website=user_data.get("website"),
|
||
|
is_pro=user_data.get("isPro"),
|
||
|
can_pay=user_data.get("canPay"),
|
||
|
orgs=orgs,
|
||
|
)
|
||
|
|
||
|
return OAuthInfo(
|
||
|
access_token=oauth_data.get("access_token"),
|
||
|
access_token_expires_at=datetime.datetime.fromtimestamp(oauth_data.get("expires_at")),
|
||
|
user_info=user_info,
|
||
|
state=oauth_data.get("state"),
|
||
|
scope=oauth_data.get("scope"),
|
||
|
)
|
||
|
|
||
|
|
||
|
def _add_oauth_routes(app: "fastapi.FastAPI", route_prefix: str) -> None:
|
||
|
"""Add OAuth routes to the FastAPI app (login, callback handler and logout)."""
|
||
|
try:
|
||
|
import fastapi
|
||
|
from authlib.integrations.base_client.errors import MismatchingStateError
|
||
|
from authlib.integrations.starlette_client import OAuth
|
||
|
from fastapi.responses import RedirectResponse
|
||
|
except ImportError as e:
|
||
|
raise ImportError(
|
||
|
"Cannot initialize OAuth to due a missing library. Please run `pip install huggingface_hub[oauth]` or add "
|
||
|
"`huggingface_hub[oauth]` to your requirements.txt file."
|
||
|
) from e
|
||
|
|
||
|
# Check environment variables
|
||
|
msg = (
|
||
|
"OAuth is required but '{}' environment variable is not set. Make sure you've enabled OAuth in your Space by"
|
||
|
" setting `hf_oauth: true` in the Space metadata."
|
||
|
)
|
||
|
if constants.OAUTH_CLIENT_ID is None:
|
||
|
raise ValueError(msg.format("OAUTH_CLIENT_ID"))
|
||
|
if constants.OAUTH_CLIENT_SECRET is None:
|
||
|
raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
|
||
|
if constants.OAUTH_SCOPES is None:
|
||
|
raise ValueError(msg.format("OAUTH_SCOPES"))
|
||
|
if constants.OPENID_PROVIDER_URL is None:
|
||
|
raise ValueError(msg.format("OPENID_PROVIDER_URL"))
|
||
|
|
||
|
# Register OAuth server
|
||
|
oauth = OAuth()
|
||
|
oauth.register(
|
||
|
name="huggingface",
|
||
|
client_id=constants.OAUTH_CLIENT_ID,
|
||
|
client_secret=constants.OAUTH_CLIENT_SECRET,
|
||
|
client_kwargs={"scope": constants.OAUTH_SCOPES},
|
||
|
server_metadata_url=constants.OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
|
||
|
)
|
||
|
|
||
|
login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix)
|
||
|
|
||
|
# Register OAuth endpoints
|
||
|
@app.get(login_uri)
|
||
|
async def oauth_login(request: fastapi.Request) -> RedirectResponse:
|
||
|
"""Endpoint that redirects to HF OAuth page."""
|
||
|
redirect_uri = _generate_redirect_uri(request)
|
||
|
return await oauth.huggingface.authorize_redirect(request, redirect_uri) # type: ignore
|
||
|
|
||
|
@app.get(callback_uri)
|
||
|
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
|
||
|
"""Endpoint that handles the OAuth callback."""
|
||
|
try:
|
||
|
oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
|
||
|
except MismatchingStateError:
|
||
|
# Parse query params
|
||
|
nb_redirects = int(request.query_params.get("_nb_redirects", 0))
|
||
|
target_url = request.query_params.get("_target_url")
|
||
|
|
||
|
# Build redirect URI with the same query params as before and bump nb_redirects count
|
||
|
query_params: Dict[str, Union[int, str]] = {"_nb_redirects": nb_redirects + 1}
|
||
|
if target_url:
|
||
|
query_params["_target_url"] = target_url
|
||
|
|
||
|
redirect_uri = f"{login_uri}?{urllib.parse.urlencode(query_params)}"
|
||
|
|
||
|
# If the user is redirected more than 3 times, it is very likely that the cookie is not working properly.
|
||
|
# (e.g. browser is blocking third-party cookies in iframe). In this case, redirect the user in the
|
||
|
# non-iframe view.
|
||
|
if nb_redirects > constants.OAUTH_MAX_REDIRECTS:
|
||
|
host = os.environ.get("SPACE_HOST")
|
||
|
if host is None: # cannot happen in a Space
|
||
|
raise RuntimeError(
|
||
|
"App is not running in a Space (SPACE_HOST environment variable is not set). Cannot redirect to non-iframe view."
|
||
|
) from None
|
||
|
host_url = "https://" + host.rstrip("/")
|
||
|
return RedirectResponse(host_url + redirect_uri)
|
||
|
|
||
|
# Redirect the user to the login page again
|
||
|
return RedirectResponse(redirect_uri)
|
||
|
|
||
|
# OAuth login worked => store the user info in the session and redirect
|
||
|
logger.debug("Successfully logged in with OAuth. Storing user info in session.")
|
||
|
request.session["oauth_info"] = oauth_info
|
||
|
return RedirectResponse(_get_redirect_target(request))
|
||
|
|
||
|
@app.get(logout_uri)
|
||
|
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
|
||
|
"""Endpoint that logs out the user (e.g. delete info from cookie session)."""
|
||
|
logger.debug("Logged out with OAuth. Removing user info from session.")
|
||
|
request.session.pop("oauth_info", None)
|
||
|
return RedirectResponse(_get_redirect_target(request))
|
||
|
|
||
|
|
||
|
def _add_mocked_oauth_routes(app: "fastapi.FastAPI", route_prefix: str = "/") -> None:
|
||
|
"""Add fake oauth routes if app is run locally and OAuth is enabled.
|
||
|
|
||
|
Using OAuth will have the same behavior as in a Space but instead of authenticating with HF, a mocked user profile
|
||
|
is added to the session.
|
||
|
"""
|
||
|
try:
|
||
|
import fastapi
|
||
|
from fastapi.responses import RedirectResponse
|
||
|
from starlette.datastructures import URL
|
||
|
except ImportError as e:
|
||
|
raise ImportError(
|
||
|
"Cannot initialize OAuth to due a missing library. Please run `pip install huggingface_hub[oauth]` or add "
|
||
|
"`huggingface_hub[oauth]` to your requirements.txt file."
|
||
|
) from e
|
||
|
|
||
|
warnings.warn(
|
||
|
"OAuth is not supported outside of a Space environment. To help you debug your app locally, the oauth endpoints"
|
||
|
" are mocked to return your profile and token. To make it work, your machine must be logged in to Huggingface."
|
||
|
)
|
||
|
mocked_oauth_info = _get_mocked_oauth_info()
|
||
|
|
||
|
login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix)
|
||
|
|
||
|
# Define OAuth routes
|
||
|
@app.get(login_uri)
|
||
|
async def oauth_login(request: fastapi.Request) -> RedirectResponse:
|
||
|
"""Fake endpoint that redirects to HF OAuth page."""
|
||
|
# Define target (where to redirect after login)
|
||
|
redirect_uri = _generate_redirect_uri(request)
|
||
|
return RedirectResponse(callback_uri + "?" + urllib.parse.urlencode({"_target_url": redirect_uri}))
|
||
|
|
||
|
@app.get(callback_uri)
|
||
|
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
|
||
|
"""Endpoint that handles the OAuth callback."""
|
||
|
request.session["oauth_info"] = mocked_oauth_info
|
||
|
return RedirectResponse(_get_redirect_target(request))
|
||
|
|
||
|
@app.get(logout_uri)
|
||
|
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
|
||
|
"""Endpoint that logs out the user (e.g. delete cookie session)."""
|
||
|
request.session.pop("oauth_info", None)
|
||
|
logout_url = URL("/").include_query_params(**request.query_params)
|
||
|
return RedirectResponse(url=logout_url, status_code=302) # see https://github.com/gradio-app/gradio/pull/9659
|
||
|
|
||
|
|
||
|
def _generate_redirect_uri(request: "fastapi.Request") -> str:
|
||
|
if "_target_url" in request.query_params:
|
||
|
# if `_target_url` already in query params => respect it
|
||
|
target = request.query_params["_target_url"]
|
||
|
else:
|
||
|
# otherwise => keep query params
|
||
|
target = "/?" + urllib.parse.urlencode(request.query_params)
|
||
|
|
||
|
redirect_uri = request.url_for("oauth_redirect_callback").include_query_params(_target_url=target)
|
||
|
redirect_uri_as_str = str(redirect_uri)
|
||
|
if redirect_uri.netloc.endswith(".hf.space"):
|
||
|
# In Space, FastAPI redirect as http but we want https
|
||
|
redirect_uri_as_str = redirect_uri_as_str.replace("http://", "https://")
|
||
|
return redirect_uri_as_str
|
||
|
|
||
|
|
||
|
def _get_redirect_target(request: "fastapi.Request", default_target: str = "/") -> str:
|
||
|
return request.query_params.get("_target_url", default_target)
|
||
|
|
||
|
|
||
|
def _get_mocked_oauth_info() -> Dict:
|
||
|
token = get_token()
|
||
|
if token is None:
|
||
|
raise ValueError(
|
||
|
"Your machine must be logged in to HF to debug an OAuth app locally. Please"
|
||
|
" run `hf auth login` or set `HF_TOKEN` as environment variable "
|
||
|
"with one of your access token. You can generate a new token in your "
|
||
|
"settings page (https://huggingface.co/settings/tokens)."
|
||
|
)
|
||
|
|
||
|
user = whoami()
|
||
|
if user["type"] != "user":
|
||
|
raise ValueError(
|
||
|
"Your machine is not logged in with a personal account. Please use a "
|
||
|
"personal access token. You can generate a new token in your settings page"
|
||
|
" (https://huggingface.co/settings/tokens)."
|
||
|
)
|
||
|
|
||
|
return {
|
||
|
"access_token": token,
|
||
|
"token_type": "bearer",
|
||
|
"expires_in": 8 * 60 * 60, # 8 hours
|
||
|
"id_token": "FOOBAR",
|
||
|
"scope": "openid profile",
|
||
|
"refresh_token": "hf_oauth__refresh_token",
|
||
|
"expires_at": int(time.time()) + 8 * 60 * 60, # 8 hours
|
||
|
"userinfo": {
|
||
|
"sub": "0123456789",
|
||
|
"name": user["fullname"],
|
||
|
"preferred_username": user["name"],
|
||
|
"profile": f"https://huggingface.co/{user['name']}",
|
||
|
"picture": user["avatarUrl"],
|
||
|
"website": "",
|
||
|
"aud": "00000000-0000-0000-0000-000000000000",
|
||
|
"auth_time": 1691672844,
|
||
|
"nonce": "aaaaaaaaaaaaaaaaaaa",
|
||
|
"iat": 1691672844,
|
||
|
"exp": 1691676444,
|
||
|
"iss": "https://huggingface.co",
|
||
|
},
|
||
|
}
|
||
|
|
||
|
|
||
|
def _get_oauth_uris(route_prefix: str = "/") -> Tuple[str, str, str]:
|
||
|
route_prefix = route_prefix.strip("/")
|
||
|
if route_prefix:
|
||
|
route_prefix = f"/{route_prefix}"
|
||
|
return (
|
||
|
f"{route_prefix}/oauth/huggingface/login",
|
||
|
f"{route_prefix}/oauth/huggingface/callback",
|
||
|
f"{route_prefix}/oauth/huggingface/logout",
|
||
|
)
|