"""Common interfaces for handling queries to prediction endpoints.""" from __future__ import annotations import logging from datetime import datetime from pathlib import Path from typing import Any, Dict import requests from folding_studio.config import API_URL, FOLDING_API_KEY, REQUEST_TIMEOUT from folding_studio.query import Query from folding_studio.utils.gcp import TokenManager, download_file_from_signed_url vertexi_ai_forwarding_url = API_URL + "predictWithVertexEndpoints" class Client: """Prediction client to send the request with.""" def __init__( self, api_key: str | None = None, token_manager: TokenManager | None = None ) -> None: self.api_key = api_key self.token_manager = token_manager @classmethod def authenticate(cls) -> Client: """Instantiates an authenticated client.""" if FOLDING_API_KEY: return cls.from_api_key(api_key=FOLDING_API_KEY) else: return cls.from_jwt() @classmethod def from_api_key(cls, api_key: str) -> Client: """Instantiates a Client object using an API key.""" return cls(api_key=api_key, token_manager=None) @classmethod def from_jwt(cls) -> Client: """Instantiates a Client object using a Google Cloud JWT token.""" token_manager = TokenManager() return cls(token_manager=token_manager) def send_request( self, query: Query, project_code: str, timeout: int = REQUEST_TIMEOUT ) -> Response: """Sends a request to the endpoint, handling authentication and errors Args: query (Query): Folding query to send. project_code (str): Project code to associate to the query. timeout (int, optional): Request timeout. Defaults to REQUEST_TIMEOUT. Returns: Response: Prediction endpoint response. """ headers = {} if self.api_key: headers["X-API-Key"] = self.api_key elif self.token_manager: headers["Authorization"] = f"Bearer {self.token_manager.get_token()}" params = {"project_code": project_code, "model": query.MODEL} try: response = requests.post( url=vertexi_ai_forwarding_url, params=params, json=query.payload, headers=headers, timeout=timeout, ) response.raise_for_status() # Raise HTTPError for bad responses (4xx, 5xx) json_response = response.json() return Response( output_signed_url=json_response["signed_url"], confidence_data=json_response["confidence_data"], ) except requests.exceptions.RequestException as e: logging.error(f"Error sending request: {e}") raise e class Response: """Class to handle the endpoints JSON responses.""" def __init__(self, output_signed_url: str, confidence_data: Dict[str, Any]) -> None: self.output_signed_url = output_signed_url self._confidence_data = confidence_data self.unzip_folder_name = "" def download_results( self, output_dir: Path, *, force: bool = False, unzip: bool = False ) -> None: """Downloads and optionally unzips the result file. Args: output_dir (Path): Path where the file will be saved. force (bool): Overwrite existing file if True. unzip (bool): Extract contents if the file is a zip. Raises: typer.Exit: If an error occurs or file already exists without `force`. """ output_path = ( output_dir / f"results_{datetime.now().strftime('%Y%m%d%H%M%S')}.zip" ) download_file_from_signed_url( self.output_signed_url, output_path, force=force, unzip=unzip, unzip_dir=output_dir, ) @property def confidence_data(self) -> Dict[str, Any]: """Prediction confidence data.""" return self._confidence_data