File size: 4,106 Bytes
44459bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Common interfaces for handling queries to prediction endpoints."""

from __future__ import annotations

import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict

import requests

from folding_studio.config import API_URL, FOLDING_API_KEY, REQUEST_TIMEOUT
from folding_studio.query import Query
from folding_studio.utils.gcp import TokenManager, download_file_from_signed_url

# URL that prediction requests are POSTed to (see Client.send_request); built by
# appending the "predictWithVertexEndpoints" route to API_URL, so API_URL is
# presumably expected to end with "/" — confirm in folding_studio.config.
# NOTE(review): "vertexi" looks like a typo of "vertex"; the name is kept as-is
# because other modules may import it.
vertexi_ai_forwarding_url = API_URL + "predictWithVertexEndpoints"


class Client:
    """Prediction client used to send folding queries to the prediction endpoint.

    Authentication is done either with an API key (sent in the ``X-API-Key``
    header) or with a ``TokenManager`` whose token is sent as a ``Bearer``
    ``Authorization`` header.
    """

    def __init__(
        self, api_key: str | None = None, token_manager: TokenManager | None = None
    ) -> None:
        """Create a client.

        Args:
            api_key (str | None): API key used for authentication. Takes
                precedence over ``token_manager`` when both are set.
            token_manager (TokenManager | None): Source of bearer tokens, used
                only when no API key is set.
        """
        self.api_key = api_key
        self.token_manager = token_manager

    @classmethod
    def authenticate(cls) -> Client:
        """Instantiate an authenticated client.

        Uses ``FOLDING_API_KEY`` from the configuration when it is set (truthy),
        otherwise falls back to JWT authentication via ``from_jwt``.
        """
        if FOLDING_API_KEY:
            return cls.from_api_key(api_key=FOLDING_API_KEY)
        return cls.from_jwt()

    @classmethod
    def from_api_key(cls, api_key: str) -> Client:
        """Instantiate a Client object that authenticates with an API key."""
        return cls(api_key=api_key, token_manager=None)

    @classmethod
    def from_jwt(cls) -> Client:
        """Instantiate a Client object that authenticates with a Google Cloud JWT token."""
        token_manager = TokenManager()
        return cls(token_manager=token_manager)

    def send_request(
        self, query: Query, project_code: str, timeout: int = REQUEST_TIMEOUT
    ) -> Response:
        """Send a query to the prediction endpoint, handling authentication and errors.

        If neither an API key nor a token manager is configured, the request
        is sent without any authentication header.

        Args:
            query (Query): Folding query to send; its ``payload`` is the JSON
                body and its ``MODEL`` is sent as the ``model`` query parameter.
            project_code (str): Project code to associate to the query.
            timeout (int, optional): Request timeout. Defaults to REQUEST_TIMEOUT.

        Returns:
            Response: Prediction endpoint response.

        Raises:
            requests.exceptions.RequestException: On network errors, timeouts,
                HTTP 4xx/5xx statuses (``HTTPError``) or, depending on the
                installed ``requests`` version, an undecodable JSON body. The
                error is logged before being re-raised.
            KeyError: If the JSON response lacks ``signed_url`` or
                ``confidence_data``.
        """
        headers: Dict[str, str] = {}
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        elif self.token_manager:
            headers["Authorization"] = f"Bearer {self.token_manager.get_token()}"
        params = {"project_code": project_code, "model": query.MODEL}

        try:
            response = requests.post(
                url=vertexi_ai_forwarding_url,
                params=params,
                json=query.payload,
                headers=headers,
                timeout=timeout,
            )
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx, 5xx)
            json_response = response.json()
        except requests.exceptions.RequestException as e:
            # Lazy %-formatting; bare `raise` keeps the original traceback intact.
            logging.error("Error sending request: %s", e)
            raise

        return Response(
            output_signed_url=json_response["signed_url"],
            confidence_data=json_response["confidence_data"],
        )


class Response:
    """Wrapper around a prediction endpoint JSON response.

    Holds the signed URL of the result archive and the confidence data
    returned alongside it.
    """

    def __init__(self, output_signed_url: str, confidence_data: Dict[str, Any]) -> None:
        """Create a response.

        Args:
            output_signed_url (str): Signed URL from which the results can be
                downloaded.
            confidence_data (Dict[str, Any]): Prediction confidence data as
                returned by the endpoint.
        """
        self.output_signed_url = output_signed_url
        self._confidence_data = confidence_data
        # Not read anywhere in this module; kept for callers that may rely on it.
        self.unzip_folder_name = ""

    def download_results(
        self, output_dir: Path | str, *, force: bool = False, unzip: bool = False
    ) -> None:
        """Download and optionally unzip the result file.

        The file is saved as ``results_<YYYYmmddHHMMSS>.zip`` (local time)
        inside ``output_dir``; when unzipping, the contents are extracted
        into ``output_dir`` as well.

        Args:
            output_dir (Path | str): Directory where the file will be saved.
                Plain strings are accepted and converted to ``Path``.
            force (bool): Overwrite existing file if True.
            unzip (bool): Extract contents if the file is a zip.

        Raises:
            Whatever ``download_file_from_signed_url`` raises — presumably
            ``typer.Exit`` on errors or when the file already exists without
            ``force``; confirm in ``folding_studio.utils.gcp``.
        """
        output_dir = Path(output_dir)
        # Second-resolution timestamp: two downloads into the same directory
        # within one second target the same file name.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        output_path = output_dir / f"results_{timestamp}.zip"
        download_file_from_signed_url(
            self.output_signed_url,
            output_path,
            force=force,
            unzip=unzip,
            unzip_dir=output_dir,
        )

    @property
    def confidence_data(self) -> Dict[str, Any]:
        """Prediction confidence data."""
        return self._confidence_data