File size: 3,396 Bytes
44459bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Helper methods for GCP."""

import logging
import shutil
import subprocess
import time
import zipfile
from collections.abc import Callable
from pathlib import Path

import requests

from folding_studio.config import REQUEST_TIMEOUT

TOKEN_EXPIRY_SECONDS = 15 * 60  # 15 Minutes


class TokenManager:
    """Class to handle token updating."""

    def __init__(self) -> None:
        """Initialize TokenManager class.

        Args:
            host_url: the url to obtain the token for.
        """
        self.access_token = None
        self.token_generation_time = 0

    def get_token(self) -> str:
        """Get the token (self updating every 15 mins).

        Return:
            The updated token
        """
        current_time = time.time()
        # Check if the token is expired
        if (
            self.access_token is None
            or current_time - self.token_generation_time >= TOKEN_EXPIRY_SECONDS
        ):
            self.access_token = get_id_token()

        return self.access_token


def get_id_token() -> str:
    """Get the user's gcp identity token via ``gcloud auth print-identity-token``.

    The command runs with ``check=False``: a failing ``gcloud`` call does not
    raise. Instead a warning is logged and an empty string is returned.

    Returns:
        str: The user's gcp token id, or ``""`` if ``gcloud`` exited with an error.

    Raises:
        OSError: Typically FileNotFoundError, if the ``gcloud`` executable
            cannot be found.
    """
    result = subprocess.run(
        ["gcloud", "auth", "print-identity-token"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        # Keep the best-effort contract (no exception), but surface why the
        # token is missing instead of failing silently.
        logging.warning(
            "`gcloud auth print-identity-token` failed with exit code %d: %s",
            result.returncode,
            result.stderr.strip(),
        )
        return ""
    return result.stdout.strip()


def download_file_from_signed_url(
    signed_url: str,
    output_path: Path,
    force: bool = False,
    unzip: bool = False,
    unzip_dir: str | None = None,
) -> None:
    """Download a file from a signed url.

    Args:
        signed_url (str): GCP signed url.
        output_path (Path): Output file path.
        force (bool, optional): Force file writing if it already exists.Defaults to False.
        unzip (bool, optional): Unzip the zip file after downloading. Defaults to False.
        unzip_dir (str | None, optional): Directory where to extract all members of the archive.
            Defaults to None.

    Raises:
        ValueError: If output file path exists but force set to false.
        ValueError: If unzip but the output path is not a zip file.
        Exception: If an error occurs during the download.
        ValueError: If unzip but the downloaded file is not a valid zip archive.
    """
    if output_path.exists() and not force:
        msg = f"The file '{output_path}' already exists. Use the --force flag to overwrite it."
        raise ValueError(msg)

    if unzip and not output_path.suffix == ".zip":
        msg = "The output path must be a zip file."
        raise ValueError(msg)

    unzip_dir = unzip_dir or output_path.with_suffix("")

    try:
        response = requests.get(signed_url, stream=True, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("wb") as f:
            shutil.copyfileobj(response.raw, f)
    except Exception as e:
        msg = f"Error downloading from signed url: {e}"
        raise Exception(msg) from e

    if unzip:
        unzip_dir.mkdir(parents=True, exist_ok=True)
        try:
            with zipfile.ZipFile(output_path, "r") as zip_ref:
                zip_ref.extractall(unzip_dir)
        except zipfile.BadZipFile:
            msg = f"File {output_path} is not a valid zip archive."
            raise ValueError(msg)

        logging.info(f"Extracted all files of {output_path} to {unzip_dir}.")