Last commit not found
"""Script to download external data for the project at build time.""" | |
import argparse | |
import logging | |
import os | |
import tarfile | |
import wget | |
def download_and_extract_models(models_url: str) -> None:
    """Download the models archive and extract it into ``data/models/``.

    The gzip-compressed tarball is saved in the current working directory,
    unpacked, and then deleted (also when extraction fails). Errors are
    logged rather than raised, so a failed download does not stop the
    script.

    Args:
        models_url: URL of the ``.tar.gz`` archive holding the models.
    """
    logging.debug("Downloading models folder.")
    models_targz = "models.tar.gz"
    models_folder = "data/models/"
    archive_path = None
    try:
        logging.debug("Downloading models from %s.", models_url)
        # Use the path wget reports back: if ``models_targz`` already exists
        # (e.g. left behind by an earlier failed run), wget saves the new
        # download under a suffixed name instead of overwriting it.
        archive_path = wget.download(models_url, models_targz)
        logging.debug("Extracting models folder.")
        with tarfile.open(archive_path, "r:gz") as tar:
            # The "data" filter rejects members that would land outside the
            # destination (absolute paths, "..", unsafe links). It exists
            # only on Python versions that ship extraction filters
            # (3.12+ and the 3.8-3.11 security backports).
            if hasattr(tarfile, "data_filter"):
                tar.extractall(models_folder, filter="data")
            else:
                tar.extractall(models_folder)
        logging.debug("Models folder downloaded and extracted.")
    except Exception as e:
        # Best effort: report (with traceback) and let the caller carry on.
        logging.exception("Error downloading models folder: %s", e)
    finally:
        # Remove the downloaded archive on success and on failure alike so
        # no stale copy lingers in the working directory.
        if archive_path is not None:
            try:
                os.remove(archive_path)
            except OSError as err:
                logging.warning("Could not remove %s: %s", archive_path, err)
def download_and_extract_item_embeddings(item_embeddings_url: str) -> None:
    """Download the item-embeddings archive and extract it into ``data/``.

    The bzip2-compressed tarball is saved in the current working directory,
    unpacked into ``data/``, and then deleted (also when extraction fails).
    Errors are logged rather than raised, so a failed download does not
    stop the script.

    NOTE(review): the archive presumably contains a top-level
    ``embed_items/`` directory (that is the path the script's entry point
    checks for) — confirm against the published archive.

    Args:
        item_embeddings_url: URL of the ``.tar.bz2`` archive holding the
            item embeddings.
    """
    logging.debug("Downloading item embeddings folder.")
    item_embeddings_tarbz = "item_embeddings.tar.bz2"
    item_embeddings_folder = "data/"
    archive_path = None
    try:
        logging.debug(
            "Downloading item embeddings from %s.", item_embeddings_url
        )
        # Use the path wget reports back: if ``item_embeddings_tarbz``
        # already exists (e.g. left behind by an earlier failed run), wget
        # saves the new download under a suffixed name instead of
        # overwriting it.
        archive_path = wget.download(item_embeddings_url, item_embeddings_tarbz)
        logging.debug("Extracting item embeddings folder.")
        with tarfile.open(archive_path, "r:bz2") as tar:
            # The "data" filter rejects members that would land outside the
            # destination (absolute paths, "..", unsafe links). It exists
            # only on Python versions that ship extraction filters
            # (3.12+ and the 3.8-3.11 security backports).
            if hasattr(tarfile, "data_filter"):
                tar.extractall(item_embeddings_folder, filter="data")
            else:
                tar.extractall(item_embeddings_folder)
        logging.debug("Item embeddings folder downloaded and extracted.")
    except Exception as e:
        # Best effort: report (with traceback) and let the caller carry on.
        logging.exception("Error downloading item embeddings folder: %s", e)
    finally:
        # Remove the downloaded archive on success and on failure alike so
        # no stale copy lingers in the working directory.
        if archive_path is not None:
            try:
                os.remove(archive_path)
            except OSError as err:
                logging.warning("Could not remove %s: %s", archive_path, err)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line arguments of the download script.

    Args:
        argv: Argument strings to parse (without the program name). When
            ``None`` (the default), ``sys.argv[1:]`` is used, exactly as
            ``argparse`` does on its own.

    Returns:
        A namespace with two string attributes: ``models`` (URL of the
        models archive) and ``embeddings`` (URL of the item-embeddings
        archive).
    """
    parser = argparse.ArgumentParser(
        description="Download external data for the project."
    )
    parser.add_argument(
        "models", type=str, help="URL to download the models folder."
    )
    parser.add_argument(
        "embeddings",
        type=str,
        help="URL to download the item embeddings folder",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    cli_args = parse_args()
    # Each entry: (directory whose presence means "already fetched",
    # progress message, downloader, URL). Entries are processed in order,
    # and a dataset is skipped when its directory already exists.
    pending_downloads = (
        (
            "data/models",
            "Downloading models...",
            download_and_extract_models,
            cli_args.models,
        ),
        (
            "data/embed_items",
            "Downloading item embeddings...",
            download_and_extract_item_embeddings,
            cli_args.embeddings,
        ),
    )
    for marker_dir, progress_msg, fetch, source_url in pending_downloads:
        if os.path.exists(marker_dir):
            continue
        logging.info(progress_msg)
        fetch(source_url)