Spaces:
Sleeping
Sleeping
from litellm.proxy._types import UserAPIKeyAuth | |
async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
    """
    Validate an OAuth2 token against the configured token info endpoint.

    Sends a GET request to the URL in ``OAUTH_TOKEN_INFO_ENDPOINT`` with the
    token as a ``Bearer`` credential, then builds a ``UserAPIKeyAuth`` from the
    JSON response.

    Environment variables read:
        OAUTH_TOKEN_INFO_ENDPOINT: URL of the token info endpoint (required).
        OAUTH_USER_ID_FIELD_NAME: Response field holding the user id
            (default ``"sub"``).
        OAUTH_USER_ROLE_FIELD_NAME: Response field holding the user role
            (default ``"role"``).
        OAUTH_USER_TEAM_ID_FIELD_NAME: Response field holding the team id
            (default ``"team_id"``).

    Args:
        token (str): The OAuth2 token to validate.

    Returns:
        UserAPIKeyAuth: Auth object carrying the token as ``api_key`` plus the
            ``user_id``, ``team_id`` and ``user_role`` read from the response
            (each is ``None`` when the field is absent).

    Raises:
        ValueError: If the caller is not a premium user, the token info
            endpoint is not set, the endpoint returns an error status, or any
            other failure occurs while validating the token.
    """
    import os

    import httpx

    from litellm._logging import verbose_proxy_logger
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        raise ValueError(
            "Oauth2 token validation is only available for premium users"
            + CommonProxyErrors.not_premium_user.value
        )

    # The bearer token is a credential, so it is deliberately kept out of logs.
    verbose_proxy_logger.debug("Oauth2 token validation started")

    # Get the token info endpoint and response field names from the environment
    token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
    user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
    user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
    user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")

    if not token_info_endpoint:
        raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")

    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    try:
        response = await client.get(token_info_endpoint, headers=headers)

        # if it's a bad token we expect it to raise an HTTPStatusError
        response.raise_for_status()

        # If we get here, the request was successful
        data = response.json()
        if not isinstance(data, dict):
            # The field lookups below need a JSON object; fail with a clear
            # message instead of an AttributeError from ``data.get``.
            raise ValueError(
                "token info endpoint returned a non-object JSON response "
                f"({type(data).__name__})"
            )

        verbose_proxy_logger.debug(
            "Oauth2 token validation, response from /token/info=%s",
            data,
        )

        # You might want to add additional checks here based on the response
        # For example, checking if the token is expired or has the correct scope
        user_id = data.get(user_id_field_name)
        user_team_id = data.get(user_team_id_field_name)
        user_role = data.get(user_role_field_name)

        return UserAPIKeyAuth(
            api_key=token,
            team_id=user_team_id,
            user_id=user_id,
            user_role=user_role,
        )
    except httpx.HTTPStatusError as e:
        # This will catch any 4xx or 5xx errors
        raise ValueError(f"Oauth 2.0 Token validation failed: {e}") from e
    except Exception as e:
        # This will catch any other errors (like network issues); the original
        # exception is chained so the root cause stays visible in tracebacks.
        raise ValueError(f"An error occurred during token validation: {e}") from e