Nolwenn
commited on
Commit
·
dbb98e1
1
Parent(s):
7aa20fe
Change to docker space
Browse files- Dockerfile +19 -0
- README.md +3 -4
- crs_arena/arena.py +6 -17
- crs_arena/utils.py +6 -49
- download_external_data.py +56 -0
- requirements.txt +1 -1
Dockerfile
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Dockerfile to run CRS Arena
|
2 |
+
FROM python:3.9-bullseye
|
3 |
+
|
4 |
+
COPY . .
|
5 |
+
|
6 |
+
# Install requirements
|
7 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
8 |
+
|
9 |
+
# Expose Hugging Face Space secrets to environment variables
|
10 |
+
RUN --mount=type=secret,id=models_folder_url,mode=0444,required=true echo "MODELS_FOLDER_URL=$(cat /run/secrets/models_folder_url)" >> .env
|
11 |
+
RUN --mount=type=secret,id=item_embeddings_url,mode=0444,required=true echo "ITEM_EMBEDDINGS_URL=$(cat /run/secrets/item_embeddings_url)" >> .env
|
12 |
+
|
13 |
+
# Download external data
|
14 |
+
RUN python download_external_data.py
|
15 |
+
|
16 |
+
EXPOSE 7860
|
17 |
+
|
18 |
+
# Run Streamlit app
|
19 |
+
CMD ["python", "-m", "streamlit", "run", "crs_arena.arena", "--server.port", "7860"]
|
README.md
CHANGED
@@ -3,11 +3,10 @@ title: CRSArena
|
|
3 |
emoji: 🐠
|
4 |
colorFrom: yellow
|
5 |
colorTo: yellow
|
6 |
-
sdk:
|
7 |
-
|
8 |
-
app_file: crs_arena/arena.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
-
Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
|
|
|
3 |
emoji: 🐠
|
4 |
colorFrom: yellow
|
5 |
colorTo: yellow
|
6 |
+
sdk: docker
|
7 |
+
app_port: 7860
|
|
|
8 |
pinned: false
|
9 |
license: mit
|
10 |
---
|
11 |
|
12 |
+
Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
|
crs_arena/arena.py
CHANGED
@@ -22,8 +22,6 @@ import asyncio
|
|
22 |
import json
|
23 |
import logging
|
24 |
import os
|
25 |
-
import threading
|
26 |
-
import time
|
27 |
from copy import deepcopy
|
28 |
from datetime import datetime
|
29 |
from typing import Dict, List
|
@@ -37,12 +35,7 @@ from battle_manager import (
|
|
37 |
)
|
38 |
from crs_fighter import CRSFighter
|
39 |
from streamlit_lottie import st_lottie_spinner
|
40 |
-
from utils import
|
41 |
-
download_and_extract_item_embeddings,
|
42 |
-
download_and_extract_models,
|
43 |
-
upload_conversation_logs_to_hf,
|
44 |
-
upload_feedback_to_gsheet,
|
45 |
-
)
|
46 |
|
47 |
from src.model.crb_crs.recommender import *
|
48 |
|
@@ -56,14 +49,6 @@ logging.basicConfig(
|
|
56 |
logger = logging.getLogger(__name__)
|
57 |
logger.setLevel(logging.INFO)
|
58 |
|
59 |
-
# Download models and data externally stored if not already downloaded
|
60 |
-
if not os.path.exists("data/models"):
|
61 |
-
logger.info("Downloading models...")
|
62 |
-
download_and_extract_models()
|
63 |
-
if not os.path.exists("data/embed_items"):
|
64 |
-
logger.info("Downloading item embeddings...")
|
65 |
-
download_and_extract_item_embeddings()
|
66 |
-
|
67 |
# Create the conversation logs directory
|
68 |
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
|
69 |
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
|
@@ -82,7 +67,9 @@ def record_vote(vote: str) -> None:
|
|
82 |
crs1_model: CRSFighter = st.session_state["crs1"]
|
83 |
crs2_model: CRSFighter = st.session_state["crs2"]
|
84 |
last_row_id = str(datetime.now())
|
85 |
-
logger.info(
|
|
|
|
|
86 |
asyncio.run(
|
87 |
upload_feedback_to_gsheet(
|
88 |
{
|
@@ -189,6 +176,7 @@ def get_crs_response(crs: CRSFighter, message: str) -> str:
|
|
189 |
# time.sleep(0.05)
|
190 |
return response
|
191 |
|
|
|
192 |
@st.dialog("Your vote has been submitted! Thank you!")
|
193 |
def feedback_dialog(row_id: int) -> None:
|
194 |
"""Pop-up dialog to provide feedback after voting.
|
@@ -208,6 +196,7 @@ def feedback_dialog(row_id: int) -> None:
|
|
208 |
st.session_state.clear()
|
209 |
st.rerun()
|
210 |
|
|
|
211 |
@st.fragment
|
212 |
def chat_col(crs_id: int, color: str):
|
213 |
"""Chat column for the CRS model.
|
|
|
22 |
import json
|
23 |
import logging
|
24 |
import os
|
|
|
|
|
25 |
from copy import deepcopy
|
26 |
from datetime import datetime
|
27 |
from typing import Dict, List
|
|
|
35 |
)
|
36 |
from crs_fighter import CRSFighter
|
37 |
from streamlit_lottie import st_lottie_spinner
|
38 |
+
from utils import upload_conversation_logs_to_hf, upload_feedback_to_gsheet
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
from src.model.crb_crs.recommender import *
|
41 |
|
|
|
49 |
logger = logging.getLogger(__name__)
|
50 |
logger.setLevel(logging.INFO)
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
# Create the conversation logs directory
|
53 |
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
|
54 |
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
|
|
|
67 |
crs1_model: CRSFighter = st.session_state["crs1"]
|
68 |
crs2_model: CRSFighter = st.session_state["crs2"]
|
69 |
last_row_id = str(datetime.now())
|
70 |
+
logger.info(
|
71 |
+
f"Vote: {last_row_id}, {user_id}, {crs1_model.name}, {crs2_model.name}, {vote}"
|
72 |
+
)
|
73 |
asyncio.run(
|
74 |
upload_feedback_to_gsheet(
|
75 |
{
|
|
|
176 |
# time.sleep(0.05)
|
177 |
return response
|
178 |
|
179 |
+
|
180 |
@st.dialog("Your vote has been submitted! Thank you!")
|
181 |
def feedback_dialog(row_id: int) -> None:
|
182 |
"""Pop-up dialog to provide feedback after voting.
|
|
|
196 |
st.session_state.clear()
|
197 |
st.rerun()
|
198 |
|
199 |
+
|
200 |
@st.fragment
|
201 |
def chat_col(crs_id: int, color: str):
|
202 |
"""Chat column for the CRS model.
|
crs_arena/utils.py
CHANGED
@@ -4,16 +4,13 @@ import ast
|
|
4 |
import asyncio
|
5 |
import logging
|
6 |
import os
|
7 |
-
import sqlite3
|
8 |
import sys
|
9 |
-
import tarfile
|
10 |
from datetime import timedelta
|
11 |
-
from typing import
|
12 |
|
13 |
import openai
|
14 |
import pandas as pd
|
15 |
import streamlit as st
|
16 |
-
import wget
|
17 |
import yaml
|
18 |
from huggingface_hub import HfApi
|
19 |
from streamlit_gsheets.gsheets_connection import GSheetsServiceAccountClient
|
@@ -23,7 +20,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|
23 |
from src.model.crs_model import CRSModel
|
24 |
|
25 |
# Initialize Hugging Face API
|
26 |
-
HF_API = HfApi(token=
|
27 |
|
28 |
|
29 |
@st.cache_resource(
|
@@ -52,52 +49,12 @@ def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
|
|
52 |
model_args = yaml.safe_load(open(model_config_file, "r"))
|
53 |
|
54 |
if "chatgpt" in model_name:
|
55 |
-
openai.api_key =
|
56 |
|
57 |
# Extract crs model from name
|
58 |
name = model_name.split("_")[0]
|
59 |
|
60 |
-
return CRSModel(name, **model_args)
|
61 |
-
)
|
62 |
-
|
63 |
-
|
64 |
-
def download_and_extract_models() -> None:
|
65 |
-
"""Downloads the models folder from the server and extracts it."""
|
66 |
-
logging.debug("Downloading models folder.")
|
67 |
-
models_url = st.secrets["models_folder_url"]
|
68 |
-
models_targz = "models.tar.gz"
|
69 |
-
models_folder = "data/models/"
|
70 |
-
try:
|
71 |
-
wget.download(models_url, models_targz)
|
72 |
-
|
73 |
-
logging.debug("Extracting models folder.")
|
74 |
-
with tarfile.open(models_targz, "r:gz") as tar:
|
75 |
-
tar.extractall(models_folder)
|
76 |
-
|
77 |
-
os.remove(models_targz)
|
78 |
-
logging.debug("Models folder downloaded and extracted.")
|
79 |
-
except Exception as e:
|
80 |
-
logging.error(f"Error downloading models folder: {e}")
|
81 |
-
|
82 |
-
|
83 |
-
def download_and_extract_item_embeddings() -> None:
|
84 |
-
"""Downloads the item embeddings folder from the server and extracts it."""
|
85 |
-
logging.debug("Downloading item embeddings folder.")
|
86 |
-
item_embeddings_url = st.secrets["item_embeddings_url"]
|
87 |
-
item_embeddings_tarbz = "item_embeddings.tar.bz2"
|
88 |
-
item_embeddings_folder = "data/"
|
89 |
-
|
90 |
-
try:
|
91 |
-
wget.download(item_embeddings_url, item_embeddings_tarbz)
|
92 |
-
|
93 |
-
logging.debug("Extracting item embeddings folder.")
|
94 |
-
with tarfile.open(item_embeddings_tarbz, "r:bz2") as tar:
|
95 |
-
tar.extractall(item_embeddings_folder)
|
96 |
-
|
97 |
-
os.remove(item_embeddings_tarbz)
|
98 |
-
logging.debug("Item embeddings folder downloaded and extracted.")
|
99 |
-
except Exception as e:
|
100 |
-
logging.error(f"Error downloading item embeddings folder: {e}")
|
101 |
|
102 |
|
103 |
async def upload_conversation_logs_to_hf(
|
@@ -122,7 +79,7 @@ async def upload_conversation_logs_to_hf(
|
|
122 |
lambda: HF_API.upload_file(
|
123 |
path_or_fileobj=conversation_log_file_path,
|
124 |
path_in_repo=repo_filename,
|
125 |
-
repo_id=
|
126 |
repo_type="dataset",
|
127 |
),
|
128 |
)
|
@@ -164,7 +121,7 @@ def _upload_feedback_to_gsheet_sync(
|
|
164 |
worksheet: Name of the worksheet to upload the feedback to.
|
165 |
"""
|
166 |
gs_connection = GSheetsServiceAccountClient(
|
167 |
-
ast.literal_eval(
|
168 |
)
|
169 |
df = gs_connection.read(worksheet=worksheet)
|
170 |
if df[df["id"] == row["id"]].empty:
|
|
|
4 |
import asyncio
|
5 |
import logging
|
6 |
import os
|
|
|
7 |
import sys
|
|
|
8 |
from datetime import timedelta
|
9 |
+
from typing import Dict
|
10 |
|
11 |
import openai
|
12 |
import pandas as pd
|
13 |
import streamlit as st
|
|
|
14 |
import yaml
|
15 |
from huggingface_hub import HfApi
|
16 |
from streamlit_gsheets.gsheets_connection import GSheetsServiceAccountClient
|
|
|
20 |
from src.model.crs_model import CRSModel
|
21 |
|
22 |
# Initialize Hugging Face API
|
23 |
+
HF_API = HfApi(token=os.environ.get("hf_token"))
|
24 |
|
25 |
|
26 |
@st.cache_resource(
|
|
|
49 |
model_args = yaml.safe_load(open(model_config_file, "r"))
|
50 |
|
51 |
if "chatgpt" in model_name:
|
52 |
+
openai.api_key = os.environ.get("openai_api_key")
|
53 |
|
54 |
# Extract crs model from name
|
55 |
name = model_name.split("_")[0]
|
56 |
|
57 |
+
return CRSModel(name, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
|
60 |
async def upload_conversation_logs_to_hf(
|
|
|
79 |
lambda: HF_API.upload_file(
|
80 |
path_or_fileobj=conversation_log_file_path,
|
81 |
path_in_repo=repo_filename,
|
82 |
+
repo_id=os.environ.get("dataset_repo"),
|
83 |
repo_type="dataset",
|
84 |
),
|
85 |
)
|
|
|
121 |
worksheet: Name of the worksheet to upload the feedback to.
|
122 |
"""
|
123 |
gs_connection = GSheetsServiceAccountClient(
|
124 |
+
ast.literal_eval(os.environ.get("gsheet"))
|
125 |
)
|
126 |
df = gs_connection.read(worksheet=worksheet)
|
127 |
if df[df["id"] == row["id"]].empty:
|
download_external_data.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Script to download external data for the project at build time."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import tarfile
|
6 |
+
|
7 |
+
import wget
|
8 |
+
|
9 |
+
|
10 |
+
def download_and_extract_models() -> None:
|
11 |
+
"""Downloads the models folder from the server and extracts it."""
|
12 |
+
logging.debug("Downloading models folder.")
|
13 |
+
models_url = os.environ.get("MODELS_FOLDER_URL")
|
14 |
+
models_targz = "models.tar.gz"
|
15 |
+
models_folder = "data/models/"
|
16 |
+
try:
|
17 |
+
wget.download(models_url, models_targz)
|
18 |
+
|
19 |
+
logging.debug("Extracting models folder.")
|
20 |
+
with tarfile.open(models_targz, "r:gz") as tar:
|
21 |
+
tar.extractall(models_folder)
|
22 |
+
|
23 |
+
os.remove(models_targz)
|
24 |
+
logging.debug("Models folder downloaded and extracted.")
|
25 |
+
except Exception as e:
|
26 |
+
logging.error(f"Error downloading models folder: {e}")
|
27 |
+
|
28 |
+
|
29 |
+
def download_and_extract_item_embeddings() -> None:
|
30 |
+
"""Downloads the item embeddings folder from the server and extracts it."""
|
31 |
+
logging.debug("Downloading item embeddings folder.")
|
32 |
+
item_embeddings_url = os.environ.get("ITEM_EMBEDDINGS_URL")
|
33 |
+
item_embeddings_tarbz = "item_embeddings.tar.bz2"
|
34 |
+
item_embeddings_folder = "data/"
|
35 |
+
|
36 |
+
try:
|
37 |
+
wget.download(item_embeddings_url, item_embeddings_tarbz)
|
38 |
+
|
39 |
+
logging.debug("Extracting item embeddings folder.")
|
40 |
+
with tarfile.open(item_embeddings_tarbz, "r:bz2") as tar:
|
41 |
+
tar.extractall(item_embeddings_folder)
|
42 |
+
|
43 |
+
os.remove(item_embeddings_tarbz)
|
44 |
+
logging.debug("Item embeddings folder downloaded and extracted.")
|
45 |
+
except Exception as e:
|
46 |
+
logging.error(f"Error downloading item embeddings folder: {e}")
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == "__main__":
|
50 |
+
if not os.path.exists("data/models"):
|
51 |
+
logging.info("Downloading models...")
|
52 |
+
download_and_extract_models()
|
53 |
+
|
54 |
+
if not os.path.exists("data/embed_items"):
|
55 |
+
logging.info("Downloading item embeddings...")
|
56 |
+
download_and_extract_item_embeddings()
|
requirements.txt
CHANGED
@@ -10,7 +10,7 @@ tiktoken==0.7.0
|
|
10 |
tenacity<9.0.0
|
11 |
thefuzz==0.22.1
|
12 |
numpy<2
|
13 |
-
streamlit==1.
|
14 |
SQLAlchemy==1.4.0
|
15 |
sent2vec==0.3.0
|
16 |
wget==3.2
|
|
|
10 |
tenacity<9.0.0
|
11 |
thefuzz==0.22.1
|
12 |
numpy<2
|
13 |
+
streamlit==1.38.0
|
14 |
SQLAlchemy==1.4.0
|
15 |
sent2vec==0.3.0
|
16 |
wget==3.2
|