File size: 2,736 Bytes
dbb98e1 d44011a dbb98e1 d44011a dbb98e1 95b7f82 dbb98e1 d44011a dbb98e1 95b7f82 dbb98e1 d44011a 37e134d d44011a 37e134d d44011a dbb98e1 95b7f82 d44011a dbb98e1 37e134d dbb98e1 37e134d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
"""Script to download external data for the project at build time."""
import argparse
import logging
import os
import tarfile
import wget
def download_and_extract_models(models_url: str) -> None:
    """Download the models archive and extract it into ``data/models/``.

    The archive is fetched with ``wget`` as ``models.tar.gz`` in the current
    working directory, unpacked as a gzip-compressed tarball, and then
    deleted. The archive is deleted even when the download or extraction
    fails, so a broken file is not left behind for the next run.

    Any error is logged (with traceback) and swallowed; nothing is raised to
    the caller.

    Args:
        models_url: URL to download the models from.
    """
    logging.debug("Downloading models folder.")
    models_targz = "models.tar.gz"
    models_folder = "data/models/"
    archive_path = models_targz
    try:
        logging.debug(f"Downloading models from {models_url}.")
        # NOTE(review): the wget package is understood to return the path it
        # actually wrote, which can differ from the requested name when that
        # file already exists -- confirm against the installed version.
        archive_path = wget.download(models_url, models_targz) or models_targz
        logging.debug("Extracting models folder.")
        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: refuse members that
                # would escape the destination (absolute paths, "..", links).
                tar.extractall(models_folder, filter="data")
            else:
                # Older interpreters without extraction filters.
                tar.extractall(models_folder)
        logging.debug("Models folder downloaded and extracted.")
    except Exception as e:
        logging.exception(f"Error downloading models folder: {e}")
    finally:
        # Clean up the archive regardless of success or failure.
        try:
            os.remove(archive_path)
        except FileNotFoundError:
            # The download never produced a file; nothing to clean up.
            pass
        except OSError as e:
            logging.error(f"Error removing {archive_path}: {e}")
def download_and_extract_item_embeddings(item_embeddings_url: str) -> None:
    """Download the item embeddings archive and extract it into ``data/``.

    The archive is fetched with ``wget`` as ``item_embeddings.tar.bz2`` in
    the current working directory, unpacked as a bzip2-compressed tarball,
    and then deleted. The archive is deleted even when the download or
    extraction fails, so a broken file is not left behind for the next run.

    Any error is logged (with traceback) and swallowed; nothing is raised to
    the caller.

    Args:
        item_embeddings_url: URL to download the item embeddings from.
    """
    logging.debug("Downloading item embeddings folder.")
    item_embeddings_tarbz = "item_embeddings.tar.bz2"
    item_embeddings_folder = "data/"
    archive_path = item_embeddings_tarbz
    try:
        logging.debug(
            f"Downloading item embeddings from {item_embeddings_url}."
        )
        # NOTE(review): the wget package is understood to return the path it
        # actually wrote, which can differ from the requested name when that
        # file already exists -- confirm against the installed version.
        archive_path = (
            wget.download(item_embeddings_url, item_embeddings_tarbz)
            or item_embeddings_tarbz
        )
        logging.debug("Extracting item embeddings folder.")
        with tarfile.open(archive_path, "r:bz2") as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: refuse members that
                # would escape the destination (absolute paths, "..", links).
                tar.extractall(item_embeddings_folder, filter="data")
            else:
                # Older interpreters without extraction filters.
                tar.extractall(item_embeddings_folder)
        logging.debug("Item embeddings folder downloaded and extracted.")
    except Exception as e:
        logging.exception(f"Error downloading item embeddings folder: {e}")
    finally:
        # Clean up the archive regardless of success or failure.
        try:
            os.remove(archive_path)
        except FileNotFoundError:
            # The download never produced a file; nothing to clean up.
            pass
        except OSError as e:
            logging.error(f"Error removing {archive_path}: {e}")
def parse_args() -> argparse.Namespace:
    """Read the two download URLs from the command line.

    Returns:
        A namespace with two string attributes, ``models`` and
        ``embeddings``, taken from the positional arguments in that order.
    """
    # (argument name, help text) for each required positional URL.
    url_arguments = (
        ("models", "URL to download the models folder."),
        ("embeddings", "URL to download the item embeddings folder"),
    )
    cli = argparse.ArgumentParser(
        description="Download external data for the project."
    )
    for arg_name, help_text in url_arguments:
        cli.add_argument(arg_name, type=str, help=help_text)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: fetch each data set only when its marker path is
    # not already present on disk.
    logging.basicConfig(level=logging.DEBUG)
    args = parse_args()
    # (marker path, progress message, downloader, source URL), in run order.
    download_steps = (
        (
            "data/models",
            "Downloading models...",
            download_and_extract_models,
            args.models,
        ),
        (
            "data/embed_items",
            "Downloading item embeddings...",
            download_and_extract_item_embeddings,
            args.embeddings,
        ),
    )
    for marker_path, notice, fetch, source_url in download_steps:
        if os.path.exists(marker_path):
            continue
        logging.info(notice)
        fetch(source_url)
|