|
"""Helper methods for GCP.""" |
|
|
|
import logging
import shutil
import subprocess
import time
import zipfile
from collections.abc import Callable
from pathlib import Path

import requests

from folding_studio.config import REQUEST_TIMEOUT
|
|
|
TOKEN_EXPIRY_SECONDS = 15 * 60 |
|
|
|
|
|
class TokenManager: |
|
"""Class to handle token updating.""" |
|
|
|
def __init__(self) -> None: |
|
"""Initialize TokenManager class. |
|
|
|
Args: |
|
host_url: the url to obtain the token for. |
|
""" |
|
self.access_token = None |
|
self.token_generation_time = 0 |
|
|
|
def get_token(self) -> str: |
|
"""Get the token (self updating every 15 mins). |
|
|
|
Return: |
|
The updated token |
|
""" |
|
current_time = time.time() |
|
|
|
if ( |
|
self.access_token is None |
|
or current_time - self.token_generation_time >= TOKEN_EXPIRY_SECONDS |
|
): |
|
self.access_token = get_id_token() |
|
|
|
return self.access_token |
|
|
|
|
|
def get_id_token() -> str:
    """Get the user's gcp token id from the gcloud CLI.

    Runs ``gcloud auth print-identity-token`` and returns its standard output
    with surrounding whitespace stripped. A non-zero exit status does not
    raise: it is logged as a warning and whatever was written to stdout
    (possibly an empty string) is still returned.

    Returns:
        str: The user's gcp token id.

    Raises:
        OSError: If the ``gcloud`` executable cannot be started
            (for example when it is not installed).
    """
    cmd_output = subprocess.run(
        ["gcloud", "auth", "print-identity-token"],
        capture_output=True,
        text=True,
        check=False,
    )

    # Surface CLI failures instead of silently handing back an empty token;
    # the return value is left untouched so existing callers keep working.
    if cmd_output.returncode != 0:
        logging.warning(
            "'gcloud auth print-identity-token' exited with code %s: %s",
            cmd_output.returncode,
            cmd_output.stderr.strip(),
        )

    return cmd_output.stdout.strip()
|
|
|
|
|
def download_file_from_signed_url( |
|
signed_url: str, |
|
output_path: Path, |
|
force: bool = False, |
|
unzip: bool = False, |
|
unzip_dir: str | None = None, |
|
) -> None: |
|
"""Download a file from a signed url. |
|
|
|
Args: |
|
signed_url (str): GCP signed url. |
|
output_path (Path): Output file path. |
|
force (bool, optional): Force file writing if it already exists.Defaults to False. |
|
unzip (bool, optional): Unzip the zip file after downloading. Defaults to False. |
|
unzip_dir (str | None, optional): Directory where to extract all members of the archive. |
|
Defaults to None. |
|
|
|
Raises: |
|
ValueError: If output file path exists but force set to false. |
|
ValueError: If unzip but the output path is not a zip file. |
|
Exception: If an error occurs during the download. |
|
ValueError: If unzip but the downloaded file is not a valid zip archive. |
|
""" |
|
if output_path.exists() and not force: |
|
msg = f"The file '{output_path}' already exists. Use the --force flag to overwrite it." |
|
raise ValueError(msg) |
|
|
|
if unzip and not output_path.suffix == ".zip": |
|
msg = "The output path must be a zip file." |
|
raise ValueError(msg) |
|
|
|
unzip_dir = unzip_dir or output_path.with_suffix("") |
|
|
|
try: |
|
response = requests.get(signed_url, stream=True, timeout=REQUEST_TIMEOUT) |
|
response.raise_for_status() |
|
output_path.parent.mkdir(parents=True, exist_ok=True) |
|
with output_path.open("wb") as f: |
|
shutil.copyfileobj(response.raw, f) |
|
except Exception as e: |
|
msg = f"Error downloading from signed url: {e}" |
|
raise Exception(msg) from e |
|
|
|
if unzip: |
|
unzip_dir.mkdir(parents=True, exist_ok=True) |
|
try: |
|
with zipfile.ZipFile(output_path, "r") as zip_ref: |
|
zip_ref.extractall(unzip_dir) |
|
except zipfile.BadZipFile: |
|
msg = f"File {output_path} is not a valid zip archive." |
|
raise ValueError(msg) |
|
|
|
logging.info(f"Extracted all files of {output_path} to {unzip_dir}.") |
|
|