File size: 7,500 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional
import requests
from .. import constants
from . import get_session, hf_raise_for_status, validate_hf_hub_args
class XetTokenType(str, Enum):
    """Kind of Xet access token that can be requested from the Hub.

    Because the enum also subclasses `str`, every member compares equal to its
    raw string value (e.g. `XetTokenType.READ == "read"`).
    """

    # The `.value` of a member is interpolated into the `xet-<value>-token` API route.
    READ = "read"
    WRITE = "write"
@dataclass(frozen=True)
class XetFileData:
    """Immutable Xet metadata attached to a single file on the Hub.

    Attributes:
        file_hash (`str`):
            Value of the Xet hash header returned by the Hub for the file.
        refresh_route (`str`):
            Route to call to obtain fresh Xet connection info. It may be either a
            path starting with `/` (resolved against a Hub endpoint) or a full URL.
    """

    file_hash: str
    refresh_route: str
@dataclass(frozen=True)
class XetConnectionInfo:
    """Immutable credentials and location needed to talk to the Xet storage service.

    Attributes:
        access_token (`str`):
            Token sent to the Xet service.
        expiration_unix_epoch (`int`):
            Expiration time of `access_token`, as a Unix timestamp in seconds
            (parsed with `int()` from the corresponding Hub response header).
        endpoint (`str`):
            URL of the Xet service to connect to.
    """

    access_token: str
    expiration_unix_epoch: int
    endpoint: str
def parse_xet_file_data_from_response(response: requests.Response) -> Optional[XetFileData]:
    """
    Extract Xet file metadata from an HTTP response, if present.

    The file hash is always read from the response headers. The refresh route is
    taken from the response's `Link` entries when the Xet auth key is present
    there, and from a dedicated response header otherwise.

    Args:
        response (`requests.Response`):
            Response whose `headers` and `links` are inspected. `None` is accepted
            and yields `None`.

    Returns:
        `Optional[XetFileData]`:
            The parsed file hash and refresh route, or `None` if either piece of
            metadata is missing.
    """
    if response is None:
        return None

    try:
        file_hash = response.headers[constants.HUGGINGFACE_HEADER_X_XET_HASH]
        links = response.links
        auth_key = constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY
        # Prefer the route advertised in the Link header; fall back to the explicit header.
        refresh_route = (
            links[auth_key]["url"]
            if auth_key in links
            else response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
        )
    except KeyError:
        # Any missing header / link entry means the response carries no usable Xet metadata.
        return None

    return XetFileData(file_hash=file_hash, refresh_route=refresh_route)
def parse_xet_connection_info_from_headers(headers: Dict[str, str]) -> Optional[XetConnectionInfo]:
    """
    Build a `XetConnectionInfo` from HTTP headers, or return `None` if they are incomplete.

    Args:
        headers (`Dict[str, str]`):
            HTTP headers expected to carry the Xet endpoint, access token and
            expiration timestamp.

    Returns:
        `XetConnectionInfo` or `None`:
            The information needed to connect to the Xet storage service, or
            `None` when a header is missing or the expiration is not an integer.
    """
    try:
        connection_info = XetConnectionInfo(
            endpoint=headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT],
            access_token=headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN],
            # ValueError / TypeError cover non-numeric or non-string expiration values.
            expiration_unix_epoch=int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION]),
        )
    except (KeyError, ValueError, TypeError):
        return None
    return connection_info
@validate_hf_hub_args
def refresh_xet_connection_info(
    *,
    file_data: XetFileData,
    headers: Dict[str, str],
    endpoint: Optional[str] = None,
) -> XetConnectionInfo:
    """
    Request fresh Xet connection info (access token, expiration, service URL) for a file.

    Args:
        file_data (`XetFileData`):
            Parsed file metadata whose `refresh_route` is called.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        endpoint (`str`, `optional`):
            Base URL prepended to `refresh_route` when the route is a path starting
            with `/`. Defaults to `constants.ENDPOINT`.

    Returns:
        `XetConnectionInfo`:
            The connection information needed to make requests to the Xet storage service.

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `file_data` has no refresh route or the Hub API response is improperly formatted.
    """
    route = file_data.refresh_route
    if route is None:
        raise ValueError("The provided xet metadata does not contain a refresh endpoint.")

    # TODO: An upcoming version of hub will prepend the endpoint to the refresh route in
    # the headers. Once that's deployed we can call fetch on the refresh route directly.
    if route.startswith("/"):
        base_url = constants.ENDPOINT if endpoint is None else endpoint
        route = f"{base_url}{route}"

    return _fetch_xet_connection_info_with_url(route, headers)
@validate_hf_hub_args
def fetch_xet_connection_info_from_repo_info(
    *,
    token_type: XetTokenType,
    repo_id: str,
    repo_type: str,
    revision: Optional[str] = None,
    headers: Dict[str, str],
    endpoint: Optional[str] = None,
    params: Optional[Dict[str, str]] = None,
) -> XetConnectionInfo:
    """
    Uses the repo info to request a xet access token from Hub.

    Args:
        token_type (`XetTokenType`):
            Type of the token to request: `"read"` or `"write"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated by a `/`.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        revision (`str`, `optional`):
            The revision of the repo to get the token for. Defaults to `"main"`.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        endpoint (`str`, `optional`):
            The endpoint to use for the request. Defaults to `constants.ENDPOINT`.
        params (`Dict[str, str]`, `optional`):
            Additional parameters to pass with the request.

    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    endpoint = endpoint if endpoint is not None else constants.ENDPOINT
    # `revision` is optional in the signature, but it is always interpolated into the URL.
    # Without a fallback, the request would go to `.../xet-<type>-token/None`.
    # Fall back to `main`, the Hub's conventional default branch.
    revision = revision if revision is not None else "main"
    url = f"{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type.value}-token/{revision}"
    return _fetch_xet_connection_info_with_url(url, headers, params)
@validate_hf_hub_args
def _fetch_xet_connection_info_with_url(
    url: str,
    headers: Dict[str, str],
    params: Optional[Dict[str, str]] = None,
) -> XetConnectionInfo:
    """
    GET `url` and parse the Xet connection info from the response headers.

    Args:
        url (`str`):
            The access token endpoint URL.
        headers (`Dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        params (`Dict[str, str]`, `optional`):
            Additional query parameters to pass with the request.

    Returns:
        `XetConnectionInfo`:
            The access token, expiration time and endpoint of the Xet storage service.

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the response does not carry the expected Xet headers.
    """
    session = get_session()
    response = session.get(headers=headers, url=url, params=params)
    hf_raise_for_status(response)

    connection_info = parse_xet_connection_info_from_headers(response.headers)  # type: ignore
    if connection_info is not None:
        return connection_info
    raise ValueError("Xet headers have not been correctly set by the server.")
|