"""Script to download external data for the project at build time.""" import argparse import logging import os import tarfile import wget def download_and_extract_models(models_url: str) -> None: """Downloads the models folder from the server and extracts it. Args: models_url: URL to download the models from. """ logging.debug("Downloading models folder.") models_targz = "models.tar.gz" models_folder = "data/models/" try: logging.debug(f"Downloading models from {models_url}.") wget.download(models_url, models_targz) logging.debug("Extracting models folder.") with tarfile.open(models_targz, "r:gz") as tar: tar.extractall(models_folder) os.remove(models_targz) logging.debug("Models folder downloaded and extracted.") except Exception as e: logging.error(f"Error downloading models folder: {e}") def download_and_extract_item_embeddings(item_embeddings_url: str) -> None: """Downloads the item embeddings folder from the server and extracts it. Args: item_embeddings_url: URL to download the item embeddings from. """ logging.debug("Downloading item embeddings folder.") item_embeddings_tarbz = "item_embeddings.tar.bz2" item_embeddings_folder = "data/" try: logging.debug( f"Downloading item embeddings from {item_embeddings_url}." ) wget.download(item_embeddings_url, item_embeddings_tarbz) logging.debug("Extracting item embeddings folder.") with tarfile.open(item_embeddings_tarbz, "r:bz2") as tar: tar.extractall(item_embeddings_folder) os.remove(item_embeddings_tarbz) logging.debug("Item embeddings folder downloaded and extracted.") except Exception as e: logging.error(f"Error downloading item embeddings folder: {e}") def parse_args() -> argparse.Namespace: """Parses command line arguments.""" parser = argparse.ArgumentParser( description="Download external data for the project." ) parser.add_argument( "models", type=str, help="URL to download the models folder." ) parser.add_argument( "embeddings", type=str, help="URL to download the item embeddings folder", ) return parser.parse_args() if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) args = parse_args() if not os.path.exists("data/models"): logging.info("Downloading models...") download_and_extract_models(args.models) if not os.path.exists("data/embed_items"): logging.info("Downloading item embeddings...") download_and_extract_item_embeddings(args.embeddings)