|
"""Common interfaces for handling queries to prediction endpoints.""" |
|
|
|
from __future__ import annotations |
|
|
|
import logging |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import Any, Dict |
|
|
|
import requests |
|
|
|
from folding_studio.config import API_URL, FOLDING_API_KEY, REQUEST_TIMEOUT |
|
from folding_studio.query import Query |
|
from folding_studio.utils.gcp import TokenManager, download_file_from_signed_url |
|
|
|
# Full URL of the prediction endpoint that `Client.send_request` POSTs to.
# Built by plain string concatenation, so API_URL is assumed to end with a
# trailing "/" — NOTE(review): confirm in folding_studio.config.
# NOTE(review): "vertexi" looks like a typo of "vertex"; the name is kept as-is
# because other modules may import it.
vertexi_ai_forwarding_url = API_URL + "predictWithVertexEndpoints"
|
|
|
|
|
class Client:
    """Prediction client to send the request with.

    At most one credential is attached per request. An API key (sent as the
    ``X-API-Key`` header) takes precedence over a ``TokenManager`` (whose token
    is sent as an ``Authorization: Bearer`` header). If neither is set, the
    request is sent without authentication headers.
    """

    def __init__(
        self, api_key: str | None = None, token_manager: TokenManager | None = None
    ) -> None:
        """Stores the credentials used by ``send_request``.

        Args:
            api_key (str | None): Folding API key. Used whenever it is truthy.
            token_manager (TokenManager | None): Token source, used only when
                ``api_key`` is not set.
        """
        self.api_key = api_key
        self.token_manager = token_manager

    @classmethod
    def authenticate(cls) -> Client:
        """Instantiates an authenticated client.

        Uses ``FOLDING_API_KEY`` from the configuration when it is set, and
        falls back to Google Cloud JWT authentication otherwise.
        """
        if FOLDING_API_KEY:
            return cls.from_api_key(api_key=FOLDING_API_KEY)
        return cls.from_jwt()

    @classmethod
    def from_api_key(cls, api_key: str) -> Client:
        """Instantiates a Client object using an API key."""
        return cls(api_key=api_key, token_manager=None)

    @classmethod
    def from_jwt(cls) -> Client:
        """Instantiates a Client object using a Google Cloud JWT token."""
        return cls(token_manager=TokenManager())

    def send_request(
        self, query: Query, project_code: str, timeout: int = REQUEST_TIMEOUT
    ) -> Response:
        """Sends a request to the endpoint, handling authentication and errors.

        Args:
            query (Query): Folding query to send. Its ``MODEL`` is sent as the
                ``model`` query parameter and its ``payload`` as the JSON body.
            project_code (str): Project code to associate to the query.
            timeout (int, optional): Request timeout. Defaults to REQUEST_TIMEOUT.

        Returns:
            Response: Prediction endpoint response.

        Raises:
            requests.exceptions.RequestException: If the request fails, the
                endpoint answers with an HTTP error status, or the body is not
                valid JSON (requests >= 2.27 raises a ``RequestException``
                subclass for that). The error is logged, then re-raised.
            KeyError: If the JSON body lacks ``signed_url`` or
                ``confidence_data``.
        """
        headers: Dict[str, str] = {}
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        elif self.token_manager:
            # Fetched outside the try block, as before: token errors are not
            # logged here but propagate to the caller unchanged.
            headers["Authorization"] = f"Bearer {self.token_manager.get_token()}"
        params = {"project_code": project_code, "model": query.MODEL}

        # Only the network/HTTP/JSON-decoding steps can raise RequestException,
        # so the try block is limited to them.
        try:
            response = requests.post(
                url=vertexi_ai_forwarding_url,
                params=params,
                json=query.payload,
                headers=headers,
                timeout=timeout,
            )
            response.raise_for_status()
            json_response = response.json()
        except requests.exceptions.RequestException as e:
            # Lazy %-formatting on a module-named logger; a bare `raise` keeps
            # the original traceback intact.
            logging.getLogger(__name__).error("Error sending request: %s", e)
            raise

        return Response(
            output_signed_url=json_response["signed_url"],
            confidence_data=json_response["confidence_data"],
        )
|
|
|
|
|
class Response:
    """Class to handle the endpoints JSON responses.

    Holds the signed URL of the result archive and the confidence data that
    ``Client.send_request`` reads from the endpoint's JSON body.
    """

    def __init__(self, output_signed_url: str, confidence_data: Dict[str, Any]) -> None:
        """Stores the response fields.

        Args:
            output_signed_url (str): Signed URL of the result archive.
            confidence_data (Dict[str, Any]): Prediction confidence data, exposed
                through the read-only ``confidence_data`` property.
        """
        self.output_signed_url = output_signed_url
        self._confidence_data = confidence_data
        # Not read anywhere in this class; kept for backward compatibility
        # with callers that may access it.
        self.unzip_folder_name = ""

    @staticmethod
    def _timestamped_output_path(output_dir: Path | str, now: datetime) -> Path:
        """Returns ``output_dir / results_<YYYYmmddHHMMSS>.zip`` for ``now``."""
        return Path(output_dir) / f"results_{now.strftime('%Y%m%d%H%M%S')}.zip"

    def download_results(
        self, output_dir: Path | str, *, force: bool = False, unzip: bool = False
    ) -> None:
        """Downloads and optionally unzips the result file.

        The archive is saved as ``results_<YYYYmmddHHMMSS>.zip`` inside
        ``output_dir``. The timestamp uses local time at second resolution, so
        two downloads into the same directory within the same second target
        the same file name.

        Args:
            output_dir (Path | str): Directory where the file will be saved. It
                is also used as the extraction directory when ``unzip`` is True.
                A plain ``str`` path is accepted.
            force (bool): Overwrite existing file if True.
            unzip (bool): Extract contents if the file is a zip.

        Raises:
            Any exception raised by ``download_file_from_signed_url`` propagates
            unchanged. NOTE(review): the previous docs stated ``typer.Exit`` on
            error, or when the file exists without ``force``; confirm this
            against that helper.
        """
        output_dir = Path(output_dir)
        output_path = self._timestamped_output_path(output_dir, datetime.now())
        download_file_from_signed_url(
            self.output_signed_url,
            output_path,
            force=force,
            unzip=unzip,
            unzip_dir=output_dir,
        )

    @property
    def confidence_data(self) -> Dict[str, Any]:
        """Prediction confidence data."""
        return self._confidence_data
|
|