File size: 12,733 Bytes
b599481 6a3ff8c 0c71ade b599481 c84464e dbb98e1 b599481 c84464e b599481 dbb98e1 0c71ade dbb98e1 b599481 3a26986 b599481 1c38f0c b599481 1c38f0c b599481 1c38f0c b599481 dbb98e1 d19bfb1 b599481 dbb98e1 31b6d03 c84464e 31b6d03 b599481 d0ff58f b599481 0c71ade 810de64 b599481 d0ff58f b599481 97c2180 b599481 ec1c3b2 0c71ade b599481 31b6d03 b599481 31b6d03 b599481 0c71ade a977fb9 0c71ade b599481 d19bfb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 |
"""Streamlit app for side-by-side battle of two CRSs.
The goal of this application is for a user to have a conversation with two
conversational recommender systems (CRSs) and vote on which one they prefer.
All the conversations are recorded and saved for future analysis. When the user
arrives on the app, they are assigned a unique ID. Then, two CRSs are chosen
for the battle depending on the number of conversations already recorded (i.e.,
the CRSs with the least number of conversations are chosen). The user interacts
with the CRSs one after the other. The user can then vote on which CRS they
prefer; upon voting, a pop-up appears giving the user the option to provide
more detailed feedback. Once the vote is submitted, all data is logged and
the app restarts for a new battle.
The app is composed of four sections:
1. Title/Introduction
2. Rules
3. Side-by-Side Battle
4. Feedback
"""
import asyncio
import json
import logging
import os
from copy import deepcopy
from datetime import datetime
from typing import Dict, List
import streamlit as st
from battle_manager import (
CONVERSATION_COUNTS,
cache_fighters,
get_crs_fighters,
get_unique_user_id,
)
from crs_fighter import CRSFighter
from streamlit_lottie import st_lottie_spinner
from utils import upload_conversation_logs_to_hf, upload_feedback_to_gsheet
from src.model.crb_crs.recommender import *
# A message is a dictionary with two keys: role and message.
Message = Dict[str, str]

# The root logger is configured at WARNING; this module's own logger is then
# set to INFO so that the vote/feedback/conversation events logged below are
# emitted.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create the conversation logs directory
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)

# Animation data passed to the Lottie spinner shown while a CRS is replying.
# The file is read inside a context manager so the handle is closed right
# after parsing instead of being left open until garbage collection.
with open(
    "asset/Typing_animation.json", "r", encoding="utf-8"
) as animation_file:
    TYPING_PLACEHOLDER_JSON = json.load(animation_file)
# Callbacks
def record_vote(vote: str) -> None:
    """Log the user's vote, upload it and open the feedback dialog.

    The vote row is identified by the current timestamp (as a string); the
    same identifier is handed to the feedback dialog.

    Args:
        vote: Name of the voted CRS model, or "tie".
    """
    last_row_id = str(datetime.now())
    user_id = st.session_state["user_id"]
    crs1_model: CRSFighter = st.session_state["crs1"]
    crs2_model: CRSFighter = st.session_state["crs2"]

    logger.info(
        f"Vote: {last_row_id}, {user_id}, {crs1_model.name}, {crs2_model.name}"
        f", {vote}"
    )

    # Row stored in the "votes" worksheet.
    vote_row = {
        "id": last_row_id,
        "user_id": user_id,
        "crs1": crs1_model.name,
        "crs2": crs2_model.name,
        "vote": vote,
    }
    upload = upload_feedback_to_gsheet(vote_row, worksheet="votes")
    asyncio.run(upload)

    feedback_dialog(row_id=last_row_id)
def record_feedback(feedback: str, row_id: str) -> None:
    """Record the user's feedback in the "feedback" worksheet.

    The feedback is logged and then uploaded together with the identifier of
    the vote it belongs to. This function does not restart the app itself.

    Args:
        feedback: User's feedback text (may be empty, feedback is optional).
        row_id: Identifier of the vote row the feedback refers to; the vote
          callback generates it as a timestamp string.
    """
    logger.info(f"Feedback: {row_id}, {feedback}")
    # asyncio.run blocks until the upload coroutine has completed.
    asyncio.run(
        upload_feedback_to_gsheet(
            {"id": row_id, "feedback": feedback}, worksheet="feedback"
        )
    )
def end_conversation(crs: CRSFighter, sentiment: str) -> None:
    """Ends the conversation with given CRS model.

    Records the conversation in the logs and moves either to the next CRS or
    to the voting section.

    The conversation is written to a local JSON file, the per-model
    conversation counter is incremented, the log file is uploaded, and the
    session-state flags that enable/disable the chat inputs and the voting
    buttons are updated.

    Args:
        crs: CRS model.
        sentiment: User's sentiment (frustrated or satisfied).
    """
    # Work on a copy so the metadata entry below is not added to the chat
    # history kept in the session state.
    messages: List[Message] = deepcopy(
        st.session_state[f"messages_{crs.fighter_id}"]
    )
    # The sentiment is stored as a trailing pseudo-message with role
    # "metadata".
    messages.append({"role": "metadata", "sentiment": sentiment})
    user_id = st.session_state["user_id"]
    logger.info(f"User {user_id} ended conversation with {crs.name}.")
    log_file_path = os.path.join(
        CONVERSATION_LOG_DIR, f"{user_id}_{crs.name}.json"
    )
    # NOTE(review): the file is opened in append mode; if the same
    # user/CRS file were ever written twice it would hold two concatenated
    # JSON documents. Presumably each user/CRS pair is written once -- confirm.
    with open(log_file_path, "a") as f:
        json.dump(messages, f)
    # Update the conversation count
    CONVERSATION_COUNTS[crs.name] += 1
    # Save the conversation logs to Hugging Face Hub (asyncio.run blocks
    # until the upload coroutine has finished).
    asyncio.run(
        upload_conversation_logs_to_hf(
            log_file_path, f"conversation_logs/{user_id}_{crs.name}.json"
        )
    )
    if crs.fighter_id == 1:
        # Disable the chat interface for the first CRS
        st.session_state["crs1_enabled"] = False
        # Enable the chat interface for the second CRS
        st.session_state["crs2_enabled"] = True
    elif crs.fighter_id == 2:
        # Disable the chat interface for the second CRS
        st.session_state["crs2_enabled"] = False
        # Enable the voting section
        st.session_state["vote_enabled"] = True
        # Scroll to the voting section
        # NOTE(review): this function is registered as a button on_click
        # callback; check the Streamlit docs on calling st.rerun() from a
        # callback and confirm this call (and its nesting) is intended.
        st.rerun()
def get_crs_response(crs: CRSFighter, message: str) -> str:
    """Gets the CRS response for the given message.

    Calls ``crs.reply`` with the user's message, the conversation history
    that precedes it and the options state stored for this fighter, then
    stores the returned state back in the session.

    Args:
        crs: CRS model.
        message: User's message.

    Returns:
        CRS response.
    """
    messages_key = f"messages_{crs.fighter_id}"
    state_key = f"state_{crs.fighter_id}"

    # Work on a deep copy so the CRS cannot mutate the chat history stored in
    # the session state.
    history = deepcopy(st.session_state[messages_key])
    # The chat column appends the user's message to the stored history before
    # calling this function. It is passed separately as `input_message`, so
    # drop it here to avoid handing it to the CRS twice.
    history = history[:-1]

    response, state = crs.reply(
        input_message=message,
        history=history,
        options_state=st.session_state.get(state_key, []),
    )
    st.session_state[state_key] = state
    return response
@st.dialog("Your vote has been submitted! Thank you!")
def feedback_dialog(row_id: str) -> None:
    """Pop-up dialog to provide feedback after voting.

    Feedback is optional and can be used to provide more detailed information
    about the user's vote. Pressing "Finish" records the text (possibly
    empty), clears the session state and reruns the app so that a new battle
    starts.

    Args:
        row_id: Unique row ID of the vote (the vote callback passes a
          timestamp string).
    """
    feedback_text = st.text_area(
        "(Optional) You can provide more detailed feedback below:"
    )
    if not st.button("Finish", use_container_width=True):
        # Nothing to do until the user confirms.
        return

    record_feedback(feedback_text, row_id)
    # Restart the app
    st.session_state.clear()
    st.rerun()
@st.fragment
def chat_col(crs_id: int, color: str):
    """Render the chat column of one CRS.

    The column shows the stored conversation, a chat input (enabled only
    while it is this CRS's turn) and the two buttons used to end the
    conversation as either frustrated or satisfied.

    Args:
        crs_id: CRS model ID (either 1 or 2).
        color: Color of the CRS model (red or large_blue).
    """
    history_key = f"messages_{crs_id}"
    fighter_key = f"crs{crs_id}"
    enabled_key = f"crs{crs_id}_enabled"

    with st.container(border=True):
        st.write(f":{color}_circle: CRS {crs_id}")

        # Replay the conversation recorded so far.
        messages_crs = st.container(height=350, border=False)
        for entry in st.session_state[history_key]:
            bubble = messages_crs.chat_message(entry["role"])
            bubble.write(entry["message"])

        prompt = st.chat_input(
            f"Send a message to CRS {crs_id}",
            key=f"prompt_crs{crs_id}",
            disabled=not st.session_state[enabled_key],
        )
        if prompt:
            # Show the user's message and store it in the chat history.
            messages_crs.chat_message("user").write(prompt)
            st.session_state[history_key].append(
                {"role": "user", "message": prompt}
            )

            crs_message = messages_crs.chat_message("assistant").empty()
            with crs_message:
                # Typing animation while waiting for the CRS response.
                with st_lottie_spinner(
                    TYPING_PLACEHOLDER_JSON, height=40, width=40
                ):
                    response_crs = get_crs_response(
                        st.session_state[fighter_key], prompt
                    )
                # Store the CRS response in the chat history, then show it.
                st.session_state[history_key].append(
                    {"role": "assistant", "message": response_crs}
                )
                crs_message.write(response_crs)

        # Buttons ending the conversation; both share the same callback and
        # differ only in label, widget key and reported sentiment.
        frustrated_col, satisfied_col = st.columns(2)
        end_buttons = (
            (
                frustrated_col,
                ":rage: Frustrated",
                f"end_frustated_crs{crs_id}",
                "frustrated",
            ),
            (
                satisfied_col,
                ":heavy_check_mark: Satisfied",
                f"end_satisfied_crs{crs_id}",
                "satisfied",
            ),
        )
        for column, label, widget_key, sentiment in end_buttons:
            clicked = column.button(
                label,
                use_container_width=True,
                key=widget_key,
                on_click=end_conversation,
                kwargs={
                    "crs": st.session_state[fighter_key],
                    "sentiment": sentiment,
                },
                disabled=not st.session_state[enabled_key],
            )
            if clicked:
                st.rerun()
# Streamlit app
# st.set_page_config(page_title="CRS Arena", layout="wide")

# CSS injected into the page to enlarge and center the spinner.
SPINNER_STYLE = """<style>
.stSpinner {
font-size: 2em;
display: flex;
justify-content: center;
align-items: center;
}
</style>
"""
st.markdown(SPINNER_STYLE, unsafe_allow_html=True)

# Cache some CRS fighters at startup
cache_fighters()

# Battle setup: done once per session, i.e. only while no user ID has been
# assigned yet.
if "user_id" not in st.session_state:
    st.session_state["user_id"] = get_unique_user_id()
    first_fighter, second_fighter = get_crs_fighters()
    st.session_state["crs1"] = first_fighter
    st.session_state["crs2"] = second_fighter
# Introduction
st.title(":gun: CRS Arena")
st.write(
    "Welcome to the CRS Arena! Here you can have a conversation with two "
    "conversational recommender systems (CRSs) and vote on which one you "
    "prefer."
)

# Navigation to the leaderboard page.
if st.button(":trophy: Leaderboard"):
    st.switch_page("leaderboard.py")

# Rules shown to the user as a Markdown bullet list.
st.header(":page_with_curl: Rules")
st.write(
    "* Chat with each CRS (one after the other) to get **movie recommendations** "
    "up until you feel satisfied or frustrated.\n"
    "* Please be patient as some CRSs may take a few seconds to respond.\n"
    "* Try to send several messages to each CRS to get a better sense of their"
    " capabilities. Don't quit after the first message!\n"
    "* To finish chatting with a CRS, click on the button corresponding to "
    "your feeling: frustrated or satisfied.\n"
    "* Vote on which CRS you prefer or declare a tie.\n"
    "* (Optional) Provide more detailed feedback after voting.\n"
)
# Side-by-Side Battle
st.header(":point_down: Side-by-Side Battle")
st.write("Let's start the battle!")

# Initialize the chat messages and the per-fighter state, once per session.
for fighter_id in (1, 2):
    if f"messages_{fighter_id}" not in st.session_state:
        st.session_state[f"messages_{fighter_id}"] = []
        st.session_state[f"state_{fighter_id}"] = None

# Initial UI flags: only the first chat is enabled; the second chat and the
# voting buttons start disabled.
for flag_name, initial_value in (
    ("crs1_enabled", True),
    ("crs2_enabled", False),
    ("vote_enabled", False),
):
    if flag_name not in st.session_state:
        st.session_state[flag_name] = initial_value

# One chat column per CRS: CRS 1 on the left, CRS 2 on the right.
col_crs1, col_crs2 = st.columns(2)
for chat_column, crs_id, crs_color in (
    (col_crs1, 1, "red"),
    (col_crs2, 2, "large_blue"),
):
    with chat_column:
        chat_col(crs_id, crs_color)
# Feedback section
container = st.container()
container.subheader(":trophy: Declare the winner!", anchor="vote")
container_col1, container_col2, container_col3 = container.columns(3)

# One button per possible outcome. All of them call record_vote and stay
# disabled until voting has been enabled in the session state.
vote_buttons = (
    (
        container_col1,
        ":red_circle: CRS 1",
        "crs1_wins",
        st.session_state["crs1"].name,
    ),
    (
        container_col2,
        ":large_blue_circle: CRS 2",
        "crs2_wins",
        st.session_state["crs2"].name,
    ),
    (container_col3, "Tie", "tie", "tie"),
)
for vote_column, vote_label, vote_key, vote_value in vote_buttons:
    vote_column.button(
        vote_label,
        use_container_width=True,
        key=vote_key,
        on_click=record_vote,
        kwargs={"vote": vote_value},
        disabled=not st.session_state["vote_enabled"],
    )
# Terms of service
TERMS_OF_SERVICE_TEXT = (
    "By using this application, you agree to the following terms of service:\n"
    "The service is a research platform. It may produce offensive content. "
    "Please do not upload any private information in the chat. The service "
    "collects the chat data and the user's vote, which may be released under a"
    " Creative Commons Attribution (CC-BY) or a similar license."
)
st.header("Terms of Service")
st.write(TERMS_OF_SERVICE_TEXT)

# Contact information
# NOTE(review): the address below reads like a redacted placeholder --
# confirm the real contact address is what ends up in the deployed file.
CONTACT_TEXT = (
    "For any questions, concerns, feedback, or bug reports, please contact "
    "Nolwenn Bernard at <[email protected]>."
)
st.header("Contact Information")
st.write(CONTACT_TEXT)
|