|
"""Script to download external data for the project at build time.""" |
|
|
|
import argparse |
|
import logging |
|
import os |
|
import tarfile |
|
|
|
import wget |
|
|
|
|
|
def download_and_extract_models(models_url: str) -> None:
    """Download the models archive and extract it into ``data/models/``.

    The archive is fetched with ``wget`` as a gzip-compressed tarball,
    unpacked, and then deleted. Any failure is logged instead of raised, so
    the function always returns ``None``.

    Args:
        models_url: URL to download the models from.
    """
    logging.debug("Downloading models folder.")
    models_targz = "models.tar.gz"
    models_folder = "data/models/"
    # Path wget actually wrote to; stays None until the download completes.
    downloaded_path = None
    try:
        logging.debug("Downloading models from %s.", models_url)
        # wget.download() returns the file name it really used. If a file
        # named ``models_targz`` already exists (e.g. left behind by an
        # earlier failed run), wget saves under a different name, so the
        # returned path must be used rather than assuming ``models_targz``.
        downloaded_path = wget.download(models_url, models_targz)

        logging.debug("Extracting models folder.")
        with tarfile.open(downloaded_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: reject members that
                # would escape the destination (absolute paths, "..",
                # links pointing outside) where the interpreter supports it.
                tar.extractall(models_folder, filter="data")
            else:
                tar.extractall(models_folder)

        logging.debug("Models folder downloaded and extracted.")
    except Exception as e:
        logging.error("Error downloading models folder: %s", e)
    finally:
        # Remove the archive on failure as well as success so a broken
        # download does not linger and confuse the next run.
        if downloaded_path is not None and os.path.exists(downloaded_path):
            try:
                os.remove(downloaded_path)
            except OSError as e:
                logging.error("Error downloading models folder: %s", e)
|
|
|
|
|
def download_and_extract_item_embeddings(item_embeddings_url: str) -> None:
    """Download the item embeddings archive and extract it into ``data/``.

    The archive is fetched with ``wget`` as a bzip2-compressed tarball,
    unpacked, and then deleted. Any failure is logged instead of raised, so
    the function always returns ``None``.

    Args:
        item_embeddings_url: URL to download the item embeddings from.
    """
    logging.debug("Downloading item embeddings folder.")
    item_embeddings_tarbz = "item_embeddings.tar.bz2"
    item_embeddings_folder = "data/"
    # Path wget actually wrote to; stays None until the download completes.
    downloaded_path = None

    try:
        logging.debug(
            "Downloading item embeddings from %s.", item_embeddings_url
        )
        # wget.download() returns the file name it really used. If a file
        # named ``item_embeddings_tarbz`` already exists (e.g. left behind
        # by an earlier failed run), wget saves under a different name, so
        # the returned path must be used rather than the requested one.
        downloaded_path = wget.download(
            item_embeddings_url, item_embeddings_tarbz
        )

        logging.debug("Extracting item embeddings folder.")
        with tarfile.open(downloaded_path, "r:bz2") as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: reject members that
                # would escape the destination (absolute paths, "..",
                # links pointing outside) where the interpreter supports it.
                tar.extractall(item_embeddings_folder, filter="data")
            else:
                tar.extractall(item_embeddings_folder)

        logging.debug("Item embeddings folder downloaded and extracted.")
    except Exception as e:
        logging.error("Error downloading item embeddings folder: %s", e)
    finally:
        # Remove the archive on failure as well as success so a broken
        # download does not linger and confuse the next run.
        if downloaded_path is not None and os.path.exists(downloaded_path):
            try:
                os.remove(downloaded_path)
            except OSError as e:
                logging.error(
                    "Error downloading item embeddings folder: %s", e
                )
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command line parser and parse the process arguments.

    Returns:
        Namespace with two string attributes, ``models`` and
        ``embeddings``, taken from the two positional arguments in that
        order.
    """
    # (argument name, help text) for each positional URL, in CLI order.
    positional_urls = (
        ("models", "URL to download the models folder."),
        ("embeddings", "URL to download the item embeddings folder"),
    )

    cli = argparse.ArgumentParser(
        description="Download external data for the project."
    )
    for arg_name, help_text in positional_urls:
        cli.add_argument(arg_name, type=str, help=help_text)

    return cli.parse_args()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: enable debug logging, read the two URLs from the
    # command line, and fetch each data folder only if its path is missing.
    logging.basicConfig(level=logging.DEBUG)

    args = parse_args()

    # (path that marks the data as present, log message, downloader, URL).
    # Steps run strictly in order; each existence check happens right
    # before its own download.
    download_steps = (
        (
            "data/models",
            "Downloading models...",
            download_and_extract_models,
            args.models,
        ),
        (
            "data/embed_items",
            "Downloading item embeddings...",
            download_and_extract_item_embeddings,
            args.embeddings,
        ),
    )

    for marker_path, announcement, fetch, source_url in download_steps:
        if os.path.exists(marker_path):
            continue
        logging.info(announcement)
        fetch(source_url)
|
|