Nolwenn committed on
Commit
dbb98e1
·
1 Parent(s): 7aa20fe

Change to docker space

Browse files
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile to run CRS Arena
2
+ FROM python:3.9-bullseye
3
+
4
+ COPY . .
5
+
6
+ # Install requirements
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ # Expose Hugging Face Space secrets to environment variables
10
+ RUN --mount=type=secret,id=models_folder_url,mode=0444,required=true echo "MODELS_FOLDER_URL=$(cat /run/secrets/models_folder_url)" >> .env
11
+ RUN --mount=type=secret,id=item_embeddings_url,mode=0444,required=true echo "ITEM_EMBEDDINGS_URL=$(cat /run/secrets/item_embeddings_url)" >> .env
12
+
13
+ # Download external data
14
+ RUN python download_external_data.py
15
+
16
+ EXPOSE 7860
17
+
18
+ # Run Streamlit app
19
+ CMD ["python", "-m", "streamlit", "run", "crs_arena.arena", "--server.port", "7860"]
README.md CHANGED
@@ -3,11 +3,10 @@ title: CRSArena
3
  emoji: 🐠
4
  colorFrom: yellow
5
  colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.42.2
8
- app_file: crs_arena/arena.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
 
3
  emoji: 🐠
4
  colorFrom: yellow
5
  colorTo: yellow
6
+ sdk: docker
7
+ app_port: 7860
 
8
  pinned: false
9
  license: mit
10
  ---
11
 
12
+ Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
crs_arena/arena.py CHANGED
@@ -22,8 +22,6 @@ import asyncio
22
  import json
23
  import logging
24
  import os
25
- import threading
26
- import time
27
  from copy import deepcopy
28
  from datetime import datetime
29
  from typing import Dict, List
@@ -37,12 +35,7 @@ from battle_manager import (
37
  )
38
  from crs_fighter import CRSFighter
39
  from streamlit_lottie import st_lottie_spinner
40
- from utils import (
41
- download_and_extract_item_embeddings,
42
- download_and_extract_models,
43
- upload_conversation_logs_to_hf,
44
- upload_feedback_to_gsheet,
45
- )
46
 
47
  from src.model.crb_crs.recommender import *
48
 
@@ -56,14 +49,6 @@ logging.basicConfig(
56
  logger = logging.getLogger(__name__)
57
  logger.setLevel(logging.INFO)
58
 
59
- # Download models and data externally stored if not already downloaded
60
- if not os.path.exists("data/models"):
61
- logger.info("Downloading models...")
62
- download_and_extract_models()
63
- if not os.path.exists("data/embed_items"):
64
- logger.info("Downloading item embeddings...")
65
- download_and_extract_item_embeddings()
66
-
67
  # Create the conversation logs directory
68
  CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
69
  os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
@@ -82,7 +67,9 @@ def record_vote(vote: str) -> None:
82
  crs1_model: CRSFighter = st.session_state["crs1"]
83
  crs2_model: CRSFighter = st.session_state["crs2"]
84
  last_row_id = str(datetime.now())
85
- logger.info(f"Vote: {last_row_id}, {user_id}, {crs1_model.name}, {crs2_model.name}, {vote}")
 
 
86
  asyncio.run(
87
  upload_feedback_to_gsheet(
88
  {
@@ -189,6 +176,7 @@ def get_crs_response(crs: CRSFighter, message: str) -> str:
189
  # time.sleep(0.05)
190
  return response
191
 
 
192
  @st.dialog("Your vote has been submitted! Thank you!")
193
  def feedback_dialog(row_id: int) -> None:
194
  """Pop-up dialog to provide feedback after voting.
@@ -208,6 +196,7 @@ def feedback_dialog(row_id: int) -> None:
208
  st.session_state.clear()
209
  st.rerun()
210
 
 
211
  @st.fragment
212
  def chat_col(crs_id: int, color: str):
213
  """Chat column for the CRS model.
 
22
  import json
23
  import logging
24
  import os
 
 
25
  from copy import deepcopy
26
  from datetime import datetime
27
  from typing import Dict, List
 
35
  )
36
  from crs_fighter import CRSFighter
37
  from streamlit_lottie import st_lottie_spinner
38
+ from utils import upload_conversation_logs_to_hf, upload_feedback_to_gsheet
 
 
 
 
 
39
 
40
  from src.model.crb_crs.recommender import *
41
 
 
49
  logger = logging.getLogger(__name__)
50
  logger.setLevel(logging.INFO)
51
 
 
 
 
 
 
 
 
 
52
  # Create the conversation logs directory
53
  CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
54
  os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
 
67
  crs1_model: CRSFighter = st.session_state["crs1"]
68
  crs2_model: CRSFighter = st.session_state["crs2"]
69
  last_row_id = str(datetime.now())
70
+ logger.info(
71
+ f"Vote: {last_row_id}, {user_id}, {crs1_model.name}, {crs2_model.name}, {vote}"
72
+ )
73
  asyncio.run(
74
  upload_feedback_to_gsheet(
75
  {
 
176
  # time.sleep(0.05)
177
  return response
178
 
179
+
180
  @st.dialog("Your vote has been submitted! Thank you!")
181
  def feedback_dialog(row_id: int) -> None:
182
  """Pop-up dialog to provide feedback after voting.
 
196
  st.session_state.clear()
197
  st.rerun()
198
 
199
+
200
  @st.fragment
201
  def chat_col(crs_id: int, color: str):
202
  """Chat column for the CRS model.
crs_arena/utils.py CHANGED
@@ -4,16 +4,13 @@ import ast
4
  import asyncio
5
  import logging
6
  import os
7
- import sqlite3
8
  import sys
9
- import tarfile
10
  from datetime import timedelta
11
- from typing import Any, Dict, List
12
 
13
  import openai
14
  import pandas as pd
15
  import streamlit as st
16
- import wget
17
  import yaml
18
  from huggingface_hub import HfApi
19
  from streamlit_gsheets.gsheets_connection import GSheetsServiceAccountClient
@@ -23,7 +20,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
23
  from src.model.crs_model import CRSModel
24
 
25
  # Initialize Hugging Face API
26
- HF_API = HfApi(token=st.secrets["hf_token"])
27
 
28
 
29
  @st.cache_resource(
@@ -52,52 +49,12 @@ def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
52
  model_args = yaml.safe_load(open(model_config_file, "r"))
53
 
54
  if "chatgpt" in model_name:
55
- openai.api_key = st.secrets["openai_api_key"]
56
 
57
  # Extract crs model from name
58
  name = model_name.split("_")[0]
59
 
60
- return CRSModel(name, **model_args), ttl=timedelta(days=3)
61
- )
62
-
63
-
64
- def download_and_extract_models() -> None:
65
- """Downloads the models folder from the server and extracts it."""
66
- logging.debug("Downloading models folder.")
67
- models_url = st.secrets["models_folder_url"]
68
- models_targz = "models.tar.gz"
69
- models_folder = "data/models/"
70
- try:
71
- wget.download(models_url, models_targz)
72
-
73
- logging.debug("Extracting models folder.")
74
- with tarfile.open(models_targz, "r:gz") as tar:
75
- tar.extractall(models_folder)
76
-
77
- os.remove(models_targz)
78
- logging.debug("Models folder downloaded and extracted.")
79
- except Exception as e:
80
- logging.error(f"Error downloading models folder: {e}")
81
-
82
-
83
- def download_and_extract_item_embeddings() -> None:
84
- """Downloads the item embeddings folder from the server and extracts it."""
85
- logging.debug("Downloading item embeddings folder.")
86
- item_embeddings_url = st.secrets["item_embeddings_url"]
87
- item_embeddings_tarbz = "item_embeddings.tar.bz2"
88
- item_embeddings_folder = "data/"
89
-
90
- try:
91
- wget.download(item_embeddings_url, item_embeddings_tarbz)
92
-
93
- logging.debug("Extracting item embeddings folder.")
94
- with tarfile.open(item_embeddings_tarbz, "r:bz2") as tar:
95
- tar.extractall(item_embeddings_folder)
96
-
97
- os.remove(item_embeddings_tarbz)
98
- logging.debug("Item embeddings folder downloaded and extracted.")
99
- except Exception as e:
100
- logging.error(f"Error downloading item embeddings folder: {e}")
101
 
102
 
103
  async def upload_conversation_logs_to_hf(
@@ -122,7 +79,7 @@ async def upload_conversation_logs_to_hf(
122
  lambda: HF_API.upload_file(
123
  path_or_fileobj=conversation_log_file_path,
124
  path_in_repo=repo_filename,
125
- repo_id=st.secrets["dataset_repo"],
126
  repo_type="dataset",
127
  ),
128
  )
@@ -164,7 +121,7 @@ def _upload_feedback_to_gsheet_sync(
164
  worksheet: Name of the worksheet to upload the feedback to.
165
  """
166
  gs_connection = GSheetsServiceAccountClient(
167
- ast.literal_eval(st.secrets["gsheet"])
168
  )
169
  df = gs_connection.read(worksheet=worksheet)
170
  if df[df["id"] == row["id"]].empty:
 
4
  import asyncio
5
  import logging
6
  import os
 
7
  import sys
 
8
  from datetime import timedelta
9
+ from typing import Dict
10
 
11
  import openai
12
  import pandas as pd
13
  import streamlit as st
 
14
  import yaml
15
  from huggingface_hub import HfApi
16
  from streamlit_gsheets.gsheets_connection import GSheetsServiceAccountClient
 
20
  from src.model.crs_model import CRSModel
21
 
22
  # Initialize Hugging Face API
23
+ HF_API = HfApi(token=os.environ.get("hf_token"))
24
 
25
 
26
  @st.cache_resource(
 
49
  model_args = yaml.safe_load(open(model_config_file, "r"))
50
 
51
  if "chatgpt" in model_name:
52
+ openai.api_key = os.environ.get("openai_api_key")
53
 
54
  # Extract crs model from name
55
  name = model_name.split("_")[0]
56
 
57
+ return CRSModel(name, **model_args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
 
60
  async def upload_conversation_logs_to_hf(
 
79
  lambda: HF_API.upload_file(
80
  path_or_fileobj=conversation_log_file_path,
81
  path_in_repo=repo_filename,
82
+ repo_id=os.environ.get("dataset_repo"),
83
  repo_type="dataset",
84
  ),
85
  )
 
121
  worksheet: Name of the worksheet to upload the feedback to.
122
  """
123
  gs_connection = GSheetsServiceAccountClient(
124
+ ast.literal_eval(os.environ.get("gsheet"))
125
  )
126
  df = gs_connection.read(worksheet=worksheet)
127
  if df[df["id"] == row["id"]].empty:
download_external_data.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to download external data for the project at build time."""
2
+
3
+ import logging
4
+ import os
5
+ import tarfile
6
+
7
+ import wget
8
+
9
+
10
+ def download_and_extract_models() -> None:
11
+ """Downloads the models folder from the server and extracts it."""
12
+ logging.debug("Downloading models folder.")
13
+ models_url = os.environ.get("MODELS_FOLDER_URL")
14
+ models_targz = "models.tar.gz"
15
+ models_folder = "data/models/"
16
+ try:
17
+ wget.download(models_url, models_targz)
18
+
19
+ logging.debug("Extracting models folder.")
20
+ with tarfile.open(models_targz, "r:gz") as tar:
21
+ tar.extractall(models_folder)
22
+
23
+ os.remove(models_targz)
24
+ logging.debug("Models folder downloaded and extracted.")
25
+ except Exception as e:
26
+ logging.error(f"Error downloading models folder: {e}")
27
+
28
+
29
+ def download_and_extract_item_embeddings() -> None:
30
+ """Downloads the item embeddings folder from the server and extracts it."""
31
+ logging.debug("Downloading item embeddings folder.")
32
+ item_embeddings_url = os.environ.get("ITEM_EMBEDDINGS_URL")
33
+ item_embeddings_tarbz = "item_embeddings.tar.bz2"
34
+ item_embeddings_folder = "data/"
35
+
36
+ try:
37
+ wget.download(item_embeddings_url, item_embeddings_tarbz)
38
+
39
+ logging.debug("Extracting item embeddings folder.")
40
+ with tarfile.open(item_embeddings_tarbz, "r:bz2") as tar:
41
+ tar.extractall(item_embeddings_folder)
42
+
43
+ os.remove(item_embeddings_tarbz)
44
+ logging.debug("Item embeddings folder downloaded and extracted.")
45
+ except Exception as e:
46
+ logging.error(f"Error downloading item embeddings folder: {e}")
47
+
48
+
49
+ if __name__ == "__main__":
50
+ if not os.path.exists("data/models"):
51
+ logging.info("Downloading models...")
52
+ download_and_extract_models()
53
+
54
+ if not os.path.exists("data/embed_items"):
55
+ logging.info("Downloading item embeddings...")
56
+ download_and_extract_item_embeddings()
requirements.txt CHANGED
@@ -10,7 +10,7 @@ tiktoken==0.7.0
10
  tenacity<9.0.0
11
  thefuzz==0.22.1
12
  numpy<2
13
- streamlit==1.42.2
14
  SQLAlchemy==1.4.0
15
  sent2vec==0.3.0
16
  wget==3.2
 
10
  tenacity<9.0.0
11
  thefuzz==0.22.1
12
  numpy<2
13
+ streamlit==1.38.0
14
  SQLAlchemy==1.4.0
15
  sent2vec==0.3.0
16
  wget==3.2