File size: 2,736 Bytes
dbb98e1
 
d44011a
dbb98e1
 
 
 
 
 
 
d44011a
 
 
 
 
 
dbb98e1
 
 
 
95b7f82
dbb98e1
 
 
 
 
 
 
 
 
 
 
 
d44011a
 
 
 
 
 
dbb98e1
 
 
 
 
95b7f82
 
 
dbb98e1
 
 
 
 
 
 
 
 
 
 
 
d44011a
 
 
 
 
 
37e134d
d44011a
 
37e134d
d44011a
 
 
 
 
 
dbb98e1
95b7f82
d44011a
 
 
dbb98e1
 
37e134d
dbb98e1
 
 
37e134d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Script to download external data for the project at build time."""

import argparse
import logging
import os
import tarfile

import wget


def download_and_extract_models(models_url: str) -> None:
    """Downloads the models archive from the server and extracts it.

    The ``.tar.gz`` archive is saved in the current working directory,
    extracted into ``data/models/`` and then deleted (also on failure).
    Errors are logged with a traceback rather than raised, so a failed
    download does not abort the caller.

    Args:
        models_url: URL to download the models ``.tar.gz`` archive from.
    """
    logging.debug("Downloading models folder.")
    models_targz = "models.tar.gz"
    models_folder = "data/models/"
    # Path wget actually wrote to; stays None if the download itself fails.
    archive_path = None
    try:
        logging.debug(f"Downloading models from {models_url}.")
        # wget.download never overwrites an existing file: it saves to a
        # suffixed name such as "models (1).tar.gz" and returns that name.
        # Using the returned path guarantees we extract what was just
        # downloaded, not a stale leftover archive.
        archive_path = wget.download(models_url, models_targz)

        logging.debug("Extracting models folder.")
        with tarfile.open(archive_path, "r:gz") as tar:
            # The archive comes from the network: use the "data" extraction
            # filter (rejects absolute paths, ".." members, device files, ...)
            # on Python versions that provide it.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(models_folder, filter="data")
            else:
                tar.extractall(models_folder)

        logging.debug("Models folder downloaded and extracted.")
    except Exception as e:
        # Errors are logged, not raised (existing behaviour is kept).
        logging.exception(f"Error downloading models folder: {e}")
    finally:
        # Delete the archive on failure too, so no leftover can shadow the
        # next download.
        if archive_path is not None and os.path.exists(archive_path):
            os.remove(archive_path)


def download_and_extract_item_embeddings(item_embeddings_url: str) -> None:
    """Downloads the item embeddings archive from the server and extracts it.

    The ``.tar.bz2`` archive is saved in the current working directory,
    extracted into ``data/`` and then deleted (also on failure). Errors are
    logged with a traceback rather than raised, so a failed download does
    not abort the caller.

    Args:
        item_embeddings_url: URL to download the item embeddings
            ``.tar.bz2`` archive from.
    """
    logging.debug("Downloading item embeddings folder.")
    item_embeddings_tarbz = "item_embeddings.tar.bz2"
    item_embeddings_folder = "data/"
    # Path wget actually wrote to; stays None if the download itself fails.
    archive_path = None

    try:
        logging.debug(
            f"Downloading item embeddings from {item_embeddings_url}."
        )
        # wget.download never overwrites an existing file: it saves to a
        # suffixed name such as "item_embeddings (1).tar.bz2" and returns
        # that name. Using the returned path guarantees we extract what was
        # just downloaded, not a stale leftover archive.
        archive_path = wget.download(
            item_embeddings_url, item_embeddings_tarbz
        )

        logging.debug("Extracting item embeddings folder.")
        with tarfile.open(archive_path, "r:bz2") as tar:
            # The archive comes from the network: use the "data" extraction
            # filter (rejects absolute paths, ".." members, device files, ...)
            # on Python versions that provide it.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(item_embeddings_folder, filter="data")
            else:
                tar.extractall(item_embeddings_folder)

        logging.debug("Item embeddings folder downloaded and extracted.")
    except Exception as e:
        # Errors are logged, not raised (existing behaviour is kept).
        logging.exception(f"Error downloading item embeddings folder: {e}")
    finally:
        # Delete the archive on failure too, so no leftover can shadow the
        # next download.
        if archive_path is not None and os.path.exists(archive_path):
            os.remove(archive_path)


def parse_args() -> argparse.Namespace:
    """Builds the command-line parser and parses ``sys.argv``.

    Returns:
        Namespace with two string attributes, ``models`` and ``embeddings``,
        holding the download URLs given as positional arguments (in that
        order).
    """
    cli = argparse.ArgumentParser(
        description="Download external data for the project."
    )
    # (positional name, help text) — registered in this order.
    positional_specs = (
        ("models", "URL to download the models folder."),
        ("embeddings", "URL to download the item embeddings folder"),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=str, help=arg_help)
    return cli.parse_args()


if __name__ == "__main__":
    # DEBUG level so every download/extract step shows up in the output.
    logging.basicConfig(level=logging.DEBUG)

    cli_args = parse_args()

    # Each step: (directory whose presence means "already downloaded",
    # log message, fetcher, URL). Steps run in order; a step is skipped
    # when its directory already exists.
    pending_steps = (
        (
            "data/models",
            "Downloading models...",
            download_and_extract_models,
            cli_args.models,
        ),
        (
            "data/embed_items",
            "Downloading item embeddings...",
            download_and_extract_item_embeddings,
            cli_args.embeddings,
        ),
    )
    for marker_dir, message, fetch, source_url in pending_steps:
        if os.path.exists(marker_dir):
            continue
        logging.info(message)
        fetch(source_url)