Nolwenn commited on
Commit
d44011a
·
1 Parent(s): 95b7f82

Update data download

Browse files
Files changed (2) hide show
  1. Dockerfile +6 -6
  2. download_external_data.py +34 -8
Dockerfile CHANGED
@@ -6,12 +6,12 @@ COPY . .
6
  # Install requirements
7
  RUN pip install --no-cache-dir -r requirements.txt
8
 
9
- # Expose Hugging Face Space secrets to environment variables
10
- RUN --mount=type=secret,id=models_folder_url,mode=0444,required=true echo "MODELS_FOLDER_URL=$(cat /run/secrets/models_folder_url)" >> .env
11
- RUN --mount=type=secret,id=item_embeddings_url,mode=0444,required=true echo "ITEM_EMBEDDINGS_URL=$(cat /run/secrets/item_embeddings_url)" >> .env
12
-
13
- # Download external data
14
- RUN python download_external_data.py
15
 
16
  EXPOSE 7860
17
 
 
6
  # Install requirements
7
  RUN pip install --no-cache-dir -r requirements.txt
8
 
9
+ # Expose Hugging Face Space secrets and download external data
10
+ RUN --mount=type=secret,id=models_folder_url,mode=0444,required=true \
11
+ --mount=type=secret,id=item_embeddings_url,mode=0444,required=true \
12
+ python download_external_data.py \
13
+ $(cat /run/secrets/models_folder_url) \
14
+ $(cat /run/secrets/item_embeddings_url)
15
 
16
  EXPOSE 7860
17
 
download_external_data.py CHANGED
@@ -1,5 +1,6 @@
1
  """Script to download external data for the project at build time."""
2
 
 
3
  import logging
4
  import os
5
  import tarfile
@@ -7,10 +8,13 @@ import tarfile
7
  import wget
8
 
9
 
10
- def download_and_extract_models() -> None:
11
- """Downloads the models folder from the server and extracts it."""
 
 
 
 
12
  logging.debug("Downloading models folder.")
13
- models_url = os.environ.get("MODELS_FOLDER_URL")
14
  models_targz = "models.tar.gz"
15
  models_folder = "data/models/"
16
  try:
@@ -27,10 +31,13 @@ def download_and_extract_models() -> None:
27
  logging.error(f"Error downloading models folder: {e}")
28
 
29
 
30
- def download_and_extract_item_embeddings() -> None:
31
- """Downloads the item embeddings folder from the server and extracts it."""
 
 
 
 
32
  logging.debug("Downloading item embeddings folder.")
33
- item_embeddings_url = os.environ.get("ITEM_EMBEDDINGS_URL")
34
  item_embeddings_tarbz = "item_embeddings.tar.bz2"
35
  item_embeddings_folder = "data/"
36
 
@@ -50,12 +57,31 @@ def download_and_extract_item_embeddings() -> None:
50
  logging.error(f"Error downloading item embeddings folder: {e}")
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  if __name__ == "__main__":
54
  logging.basicConfig(level=logging.DEBUG)
 
 
 
55
  if not os.path.exists("data/models"):
56
  logging.info("Downloading models...")
57
- download_and_extract_models()
58
 
59
  if not os.path.exists("data/embed_items"):
60
  logging.info("Downloading item embeddings...")
61
- download_and_extract_item_embeddings()
 
1
  """Script to download external data for the project at build time."""
2
 
3
+ import argparse
4
  import logging
5
  import os
6
  import tarfile
 
8
  import wget
9
 
10
 
11
+ def download_and_extract_models(models_url: str) -> None:
12
+ """Downloads the models folder from the server and extracts it.
13
+
14
+ Args:
15
+ models_url: URL to download the models from.
16
+ """
17
  logging.debug("Downloading models folder.")
 
18
  models_targz = "models.tar.gz"
19
  models_folder = "data/models/"
20
  try:
 
31
  logging.error(f"Error downloading models folder: {e}")
32
 
33
 
34
+ def download_and_extract_item_embeddings(item_embeddings_url: str) -> None:
35
+ """Downloads the item embeddings folder from the server and extracts it.
36
+
37
+ Args:
38
+ item_embeddings_url: URL to download the item embeddings from.
39
+ """
40
  logging.debug("Downloading item embeddings folder.")
 
41
  item_embeddings_tarbz = "item_embeddings.tar.bz2"
42
  item_embeddings_folder = "data/"
43
 
 
57
  logging.error(f"Error downloading item embeddings folder: {e}")
58
 
59
 
60
+ def parse_args() -> argparse.Namespace:
61
+ """Parses command line arguments."""
62
+ parser = argparse.ArgumentParser(
63
+ description="Download external data for the project."
64
+ )
65
+ parser.add_argument(
66
+ "models-url", type=str, help="URL to download the models folder."
67
+ )
68
+ parser.add_argument(
69
+ "item-embeddings-url",
70
+ type=str,
71
+ help="URL to download the item embeddings folder",
72
+ )
73
+ return parser.parse_args()
74
+
75
+
76
  if __name__ == "__main__":
77
  logging.basicConfig(level=logging.DEBUG)
78
+
79
+ args = parse_args()
80
+
81
  if not os.path.exists("data/models"):
82
  logging.info("Downloading models...")
83
+ download_and_extract_models(args.models_url)
84
 
85
  if not os.path.exists("data/embed_items"):
86
  logging.info("Downloading item embeddings...")
87
+ download_and_extract_item_embeddings(args.item_embeddings_url)