File size: 4,106 Bytes
44459bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
"""Common interfaces for handling queries to prediction endpoints."""
from __future__ import annotations
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
import requests
from folding_studio.config import API_URL, FOLDING_API_KEY, REQUEST_TIMEOUT
from folding_studio.query import Query
from folding_studio.utils.gcp import TokenManager, download_file_from_signed_url
# URL that Client.send_request POSTs prediction queries to: API_URL with
# "predictWithVertexEndpoints" appended directly (no separator is inserted, so
# this presumably relies on API_URL ending in "/" — confirm in
# folding_studio.config).
# NOTE(review): "vertexi" looks like a typo of "vertex"; the name is kept as-is
# because other modules may import it.
vertexi_ai_forwarding_url = API_URL + "predictWithVertexEndpoints"
class Client:
    """Prediction client used to send folding queries to the prediction endpoint.

    Requests are authenticated either with an API key (sent in the
    ``X-API-Key`` header) or with a bearer token obtained from a
    ``TokenManager`` (sent in the ``Authorization`` header).
    """

    def __init__(
        self, api_key: str | None = None, token_manager: TokenManager | None = None
    ) -> None:
        """Store the credentials used by :meth:`send_request`.

        Args:
            api_key (str | None): API key. When set (truthy) it takes
                precedence over ``token_manager``.
            token_manager (TokenManager | None): Provider of bearer tokens,
                used only when no API key is set.
        """
        self.api_key = api_key
        self.token_manager = token_manager

    @classmethod
    def authenticate(cls) -> Client:
        """Instantiate an authenticated client.

        Uses the configured ``FOLDING_API_KEY`` when it is set (truthy);
        otherwise falls back to Google Cloud JWT authentication via
        :meth:`from_jwt`.
        """
        if FOLDING_API_KEY:
            return cls.from_api_key(api_key=FOLDING_API_KEY)
        return cls.from_jwt()

    @classmethod
    def from_api_key(cls, api_key: str) -> Client:
        """Instantiate a client that authenticates with an API key."""
        return cls(api_key=api_key, token_manager=None)

    @classmethod
    def from_jwt(cls) -> Client:
        """Instantiate a client that authenticates with a Google Cloud JWT token."""
        return cls(token_manager=TokenManager())

    def send_request(
        self, query: Query, project_code: str, timeout: int = REQUEST_TIMEOUT
    ) -> Response:
        """Send a query to the prediction endpoint, handling authentication and errors.

        Args:
            query (Query): Folding query to send; its ``MODEL`` and ``payload``
                attributes are forwarded to the endpoint.
            project_code (str): Project code to associate with the query.
            timeout (int, optional): Request timeout. Defaults to REQUEST_TIMEOUT.

        Returns:
            Response: Prediction endpoint response.

        Raises:
            requests.exceptions.RequestException: If the request fails, the
                endpoint answers with a 4xx/5xx status, or (on requests >= 2.27)
                the body is not valid JSON. The error is logged, including the
                server's response body when available, then re-raised.
            KeyError: If the JSON response lacks ``signed_url`` or
                ``confidence_data``.
        """
        headers: Dict[str, str] = {}
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        elif self.token_manager:
            headers["Authorization"] = f"Bearer {self.token_manager.get_token()}"
        params = {"project_code": project_code, "model": query.MODEL}

        try:
            response = requests.post(
                url=vertexi_ai_forwarding_url,
                params=params,
                json=query.payload,
                headers=headers,
                timeout=timeout,
            )
            response.raise_for_status()  # Raise HTTPError for 4xx/5xx responses.
            json_response = response.json()
        except requests.exceptions.RequestException as err:
            # Use a named logger rather than the module-level logging.error(),
            # which logs to the root logger and implicitly calls basicConfig()
            # on the application's behalf.
            logger = logging.getLogger(__name__)
            # requests.Response.__bool__ is False for 4xx/5xx responses, so
            # compare against None explicitly instead of relying on truthiness.
            error_response = getattr(err, "response", None)
            if error_response is not None and error_response.text:
                logger.error(
                    "Error sending request: %s (response body: %s)",
                    err,
                    error_response.text,
                )
            else:
                logger.error("Error sending request: %s", err)
            raise  # Bare raise keeps the original traceback intact.

        return Response(
            output_signed_url=json_response["signed_url"],
            confidence_data=json_response["confidence_data"],
        )
class Response:
    """Wrapper around a prediction endpoint's JSON response.

    Holds the signed URL of the result archive together with the confidence
    data returned alongside it.
    """

    def __init__(self, output_signed_url: str, confidence_data: Dict[str, Any]) -> None:
        """Store the response fields.

        Args:
            output_signed_url (str): Signed URL from which the results can be
                downloaded.
            confidence_data (Dict[str, Any]): Confidence data returned with
                the prediction.
        """
        self.output_signed_url = output_signed_url
        self._confidence_data = confidence_data
        # Public attribute initialised empty; it is not read or updated
        # anywhere else in this class.
        self.unzip_folder_name = ""

    @property
    def confidence_data(self) -> Dict[str, Any]:
        """Confidence data reported for the prediction (read-only)."""
        return self._confidence_data

    def download_results(
        self, output_dir: Path, *, force: bool = False, unzip: bool = False
    ) -> None:
        """Download the result archive into ``output_dir``, optionally unzipping it.

        The archive is saved as ``results_<YYYYmmddHHMMSS>.zip`` (local time)
        inside ``output_dir``; when unzipping, contents are extracted into
        ``output_dir`` as well.

        Args:
            output_dir (Path): Directory where the archive is saved.
            force (bool): Overwrite an existing file if True.
            unzip (bool): Extract the archive's contents if True.

        Raises:
            typer.Exit: Per the original contract, presumably raised by
                ``download_file_from_signed_url`` on error or when the file
                already exists without ``force`` — confirm against that helper.
        """
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        archive_path = output_dir / f"results_{timestamp}.zip"
        download_file_from_signed_url(
            self.output_signed_url,
            archive_path,
            force=force,
            unzip=unzip,
            unzip_dir=output_dir,
        )
|