CRSArena / download_external_data.py
Nolwenn
Update args name
37e134d
raw
history blame
2.74 kB
"""Script to download external data for the project at build time."""
import argparse
import logging
import os
import tarfile
import wget
def download_and_extract_models(models_url: str) -> None:
    """Downloads the models archive and extracts it into ``data/models/``.

    The archive is saved as ``models.tar.gz`` in the current working
    directory, unpacked, and then deleted. The operation is best-effort:
    any failure is logged at ERROR level and is not re-raised.

    Args:
        models_url: URL to download the models from.
    """
    logging.debug("Downloading models folder.")
    models_targz = "models.tar.gz"
    models_folder = "data/models/"
    try:
        logging.debug("Downloading models from %s.", models_url)
        wget.download(models_url, models_targz)
        logging.debug("Extracting models folder.")
        with tarfile.open(models_targz, "r:gz") as tar:
            # The archive comes from the network: when the interpreter
            # supports extraction filters, use the "data" filter so members
            # cannot be written outside the destination folder.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(models_folder, filter="data")
            else:
                tar.extractall(models_folder)
        logging.debug("Models folder downloaded and extracted.")
    except Exception as e:
        logging.error("Error downloading models folder: %s", e)
    finally:
        # Remove the archive on failure as well, so a partial or corrupt
        # download is never left behind for a later run.
        try:
            if os.path.exists(models_targz):
                os.remove(models_targz)
        except OSError as e:
            logging.error("Error removing %s: %s", models_targz, e)
def download_and_extract_item_embeddings(item_embeddings_url: str) -> None:
    """Downloads the item embeddings archive and extracts it into ``data/``.

    The archive is saved as ``item_embeddings.tar.bz2`` in the current
    working directory, unpacked, and then deleted. The operation is
    best-effort: any failure is logged at ERROR level and is not re-raised.

    Args:
        item_embeddings_url: URL to download the item embeddings from.
    """
    logging.debug("Downloading item embeddings folder.")
    item_embeddings_tarbz = "item_embeddings.tar.bz2"
    item_embeddings_folder = "data/"
    try:
        logging.debug(
            "Downloading item embeddings from %s.", item_embeddings_url
        )
        wget.download(item_embeddings_url, item_embeddings_tarbz)
        logging.debug("Extracting item embeddings folder.")
        with tarfile.open(item_embeddings_tarbz, "r:bz2") as tar:
            # The archive comes from the network: when the interpreter
            # supports extraction filters, use the "data" filter so members
            # cannot be written outside the destination folder.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(item_embeddings_folder, filter="data")
            else:
                tar.extractall(item_embeddings_folder)
        logging.debug("Item embeddings folder downloaded and extracted.")
    except Exception as e:
        logging.error("Error downloading item embeddings folder: %s", e)
    finally:
        # Remove the archive on failure as well, so a partial or corrupt
        # download is never left behind for a later run.
        try:
            if os.path.exists(item_embeddings_tarbz):
                os.remove(item_embeddings_tarbz)
        except OSError as e:
            logging.error("Error removing %s: %s", item_embeddings_tarbz, e)
def parse_args() -> argparse.Namespace:
    """Builds the command line parser and parses the process arguments.

    Two required positional arguments are accepted, in this order:
    ``models`` and ``embeddings``, both plain strings holding a URL.

    Returns:
        Namespace with the ``models`` and ``embeddings`` attributes.
    """
    # (argument name, help text) for each positional, in command line order.
    positionals = (
        ("models", "URL to download the models folder."),
        ("embeddings", "URL to download the item embeddings folder"),
    )

    cli = argparse.ArgumentParser(
        description="Download external data for the project."
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, type=str, help=help_text)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: fetch each external resource that is not yet on disk.
    logging.basicConfig(level=logging.DEBUG)
    args = parse_args()

    # Each entry: (path whose existence means the step is skipped,
    # progress message, download function, URL taken from the command line).
    steps = (
        (
            "data/models",
            "Downloading models...",
            download_and_extract_models,
            args.models,
        ),
        (
            "data/embed_items",
            "Downloading item embeddings...",
            download_and_extract_item_embeddings,
            args.embeddings,
        ),
    )
    for marker_path, progress_message, fetch, url in steps:
        if os.path.exists(marker_path):
            continue
        logging.info(progress_message)
        fetch(url)