Update crs_arena/arena.py
Browse files- crs_arena/arena.py +18 -8
crs_arena/arena.py
CHANGED
@@ -37,6 +37,7 @@ from battle_manager import (
|
|
37 |
get_unique_user_id,
|
38 |
)
|
39 |
from crs_fighter import CRSFighter
|
|
|
40 |
from utils import (
|
41 |
download_and_extract_item_embeddings,
|
42 |
download_and_extract_models,
|
@@ -68,6 +69,8 @@ if not os.path.exists("data/embed_items"):
|
|
68 |
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
|
69 |
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
|
70 |
|
|
|
|
|
71 |
|
72 |
# Callbacks
|
73 |
def record_vote(vote: str) -> None:
|
@@ -233,15 +236,22 @@ def chat_col(crs_id: int, color: str):
|
|
233 |
st.session_state[f"messages_{crs_id}"].append(
|
234 |
{"role": "user", "message": prompt}
|
235 |
)
|
236 |
-
# Get the CRS response
|
237 |
-
response_crs = messages_crs.chat_message("assistant").write_stream(
|
238 |
-
get_crs_response(st.session_state[f"crs{crs_id}"], prompt)
|
239 |
-
)
|
240 |
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
|
246 |
frustrated_col, satisfied_col = st.columns(2)
|
247 |
if frustrated_col.button(
|
|
|
37 |
get_unique_user_id,
|
38 |
)
|
39 |
from crs_fighter import CRSFighter
|
40 |
+
from streamlit_lottie import st_lottie_spinner
|
41 |
from utils import (
|
42 |
download_and_extract_item_embeddings,
|
43 |
download_and_extract_models,
|
|
|
69 |
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
|
70 |
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
|
71 |
|
72 |
+
TYPING_PLACEHOLDER_JSON = json.load(open("asset/Typing_animation.json", "r"))
|
73 |
+
|
74 |
|
75 |
# Callbacks
|
76 |
def record_vote(vote: str) -> None:
|
|
|
236 |
st.session_state[f"messages_{crs_id}"].append(
|
237 |
{"role": "user", "message": prompt}
|
238 |
)
|
|
|
|
|
|
|
|
|
239 |
|
240 |
+
crs_message = messages_crs.chat_message("assistant").empty()
|
241 |
+
with crs_message:
|
242 |
+
# Placeholder for the CRS response
|
243 |
+
with st_lottie_spinner(
|
244 |
+
TYPING_PLACEHOLDER_JSON, height=40, width=40
|
245 |
+
):
|
246 |
+
response_crs = get_crs_response(
|
247 |
+
st.session_state[f"crs{crs_id}"], prompt
|
248 |
+
)
|
249 |
+
# Add CRS response to chat history
|
250 |
+
st.session_state[f"messages_{crs_id}"].append(
|
251 |
+
{"role": "assistant", "message": response_crs}
|
252 |
+
)
|
253 |
+
# Display the CRS response
|
254 |
+
crs_message.write(response_crs)
|
255 |
|
256 |
frustrated_col, satisfied_col = st.columns(2)
|
257 |
if frustrated_col.button(
|