"""Streamlit app for side-by-side battle of two CRSs. |
The goal of this application is for a user to have a conversation with two |
conversational recommender systems (CRSs) and vote on which one they prefer. |
All the conversations are recorded and saved for future analysis. When the user |
arrives on the app, they are assigned a unique ID. Then, two CRSs are chosen |
for the battle depending on the number of conversations already recorded (i.e., |
the CRSs with the fewest conversations are chosen). The user interacts
with the CRSs one after the other and then votes on which CRS they
prefer. Upon voting, a pop-up appears giving the user the option to provide
more detailed feedback. Once the vote is submitted, all data is logged and
the app restarts for a new battle.
The app is composed of four sections: |
1. Title/Introduction |
2. Rules |
3. Side-by-Side Battle |
4. Feedback |
""" |
import asyncio |
import json |
import logging |
import os |
from copy import deepcopy |
from datetime import datetime |
from typing import Dict, List |
import streamlit as st |
from battle_manager import ( |
cache_fighters, |
get_crs_fighters, |
get_unique_user_id, |
) |
from crs_fighter import CRSFighter |
from streamlit_lottie import st_lottie_spinner |
from utils import upload_conversation_logs_to_hf, upload_feedback_to_gsheet |
from src.model.crb_crs.recommender import * |
# A single chat entry, e.g. {"role": "user", "message": "..."}.
Message = Dict[str, str]

logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# This module logs at INFO while the root logger stays at WARNING.
logger.setLevel(logging.INFO)

# Local directory where per-user conversation logs are written before upload.
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)

# Lottie animation shown while waiting for a CRS reply. It is loaded once at
# import time; the context manager guarantees the file handle is closed.
with open("asset/Typing_animation.json", "r", encoding="utf-8") as _typing_file:
    TYPING_PLACEHOLDER_JSON = json.load(_typing_file)
def record_vote(vote: str) -> None:
    """Persist the user's vote and open the optional feedback dialog.

    The vote row is keyed by the current timestamp (as a string). It is
    logged, uploaded to the ``"votes"`` worksheet, and the same key is then
    handed to the feedback dialog so that feedback can reference this vote.

    Args:
        vote: Name of the winning CRS, or ``"tie"``.
    """
    session = st.session_state
    voter = session["user_id"]
    first: CRSFighter = session["crs1"]
    second: CRSFighter = session["crs2"]
    row_id = str(datetime.now())

    logger.info(
        f"Vote: {row_id}, {voter}, {first.name}, {second.name}"
        f", {vote}"
    )

    vote_row = {
        "id": row_id,
        "user_id": voter,
        "crs1": first.name,
        "crs2": second.name,
        "vote": vote,
    }
    asyncio.run(upload_feedback_to_gsheet(vote_row, worksheet="votes"))

    feedback_dialog(row_id=row_id)
def record_feedback(feedback: str, row_id: str) -> None:
    """Record the user's optional free-text feedback for a vote.

    The feedback is logged and uploaded to the ``"feedback"`` worksheet,
    keyed by the same ``row_id`` as the vote it refers to. This function does
    not restart the app; clearing the session and rerunning is done by the
    caller (``feedback_dialog``).

    Args:
        feedback: User's feedback text (possibly empty, since it is optional).
        row_id: ID of the vote row this feedback belongs to. This is the
            timestamp string produced by ``record_vote``.
    """
    logger.info(f"Feedback: {row_id}, {feedback}")
    asyncio.run(
        upload_feedback_to_gsheet(
            {"id": row_id, "feedback": feedback}, "feedback"
        )
    )
def end_conversation(crs: CRSFighter, sentiment: str) -> None:
    """End the conversation with the given CRS and advance the battle.

    The function:

    1. Appends a metadata entry with the user's sentiment to a copy of the
       conversation.
    2. Writes the copy to a JSON file in ``CONVERSATION_LOG_DIR``.
    3. Uploads that file via ``upload_conversation_logs_to_hf``.
    4. Moves to the next stage. Ending CRS 1 enables CRS 2, and ending CRS 2
       enables voting.

    Args:
        crs: CRS fighter whose conversation is ending.
        sentiment: User's sentiment, ``"frustrated"`` or ``"satisfied"``.
    """
    # Deep copy so the metadata entry does not leak into the live chat
    # history that the chat column renders from session state.
    messages: List[Message] = deepcopy(
        st.session_state[f"messages_{crs.fighter_id}"]
    )
    messages.append({"role": "metadata", "sentiment": sentiment})

    user_id = st.session_state["user_id"]
    logger.info(f"User {user_id} ended conversation with {crs.name}.")

    log_file_path = os.path.join(
        CONVERSATION_LOG_DIR, f"{user_id}_{crs.name}.json"
    )
    # Overwrite rather than append: appending a second JSON document to an
    # existing file would leave it unparseable as JSON.
    with open(log_file_path, "w", encoding="utf-8") as f:
        json.dump(messages, f)

    # NOTE(review): CONVERSATION_COUNTS is neither defined nor explicitly
    # imported in this file, so referencing the bare name raised NameError
    # here. That aborted the upload and the turn hand-over below. It is looked
    # up defensively until its intended source (presumably battle_manager;
    # confirm) is imported explicitly.
    conversation_counts = globals().get("CONVERSATION_COUNTS")
    if conversation_counts is not None:
        conversation_counts[crs.name] += 1
    else:
        logger.warning(
            "CONVERSATION_COUNTS is not defined; count for %s not updated.",
            crs.name,
        )

    asyncio.run(
        upload_conversation_logs_to_hf(
            log_file_path, f"conversation_logs/{user_id}_{crs.name}.json"
        )
    )

    if crs.fighter_id == 1:
        st.session_state["crs1_enabled"] = False
        st.session_state["crs2_enabled"] = True
    elif crs.fighter_id == 2:
        st.session_state["crs2_enabled"] = False
        st.session_state["vote_enabled"] = True

    # NOTE(review): this function runs as a button on_click callback, and the
    # calling buttons already call st.rerun() when clicked. This call is
    # therefore likely redundant; confirm before removing.
    st.rerun()
def get_crs_response(crs: CRSFighter, message: str) -> str:
    """Get the CRS reply to ``message`` and store the CRS's new options state.

    The function delegates to ``crs.reply`` and passes:

    - the user's message,
    - the conversation history kept in session state,
    - the options state saved from the previous turn.

    The state returned alongside the reply is written back to session state
    under ``state_<fighter_id>``.

    Args:
        crs: CRS fighter to query.
        message: User's message.

    Returns:
        CRS response.
    """
    fighter_id = crs.fighter_id
    # chat_col appends the user's message to this list before calling, so
    # the history passed here already ends with ``message``.
    # NOTE(review): confirm CRSFighter.reply expects that, rather than a
    # history that excludes the current message.
    history = st.session_state[f"messages_{fighter_id}"]
    # The main script seeds ``state_<id>`` with None, so the [] default only
    # applies if that initialisation has not run.
    previous_state = st.session_state.get(f"state_{fighter_id}", [])

    response, state = crs.reply(
        input_message=message,
        history=history,
        options_state=previous_state,
    )
    st.session_state[f"state_{fighter_id}"] = state
    return response
@st.dialog("Your vote has been submitted! Thank you!")
def feedback_dialog(row_id: str) -> None:
    """Show a pop-up dialog for optional feedback after voting.

    Clicking "Finish" records the (possibly empty) feedback against the vote.
    It then clears the whole session and reruns the app, so the next run
    starts a new battle with a new user ID and new fighters.

    Args:
        row_id: ID of the vote row. This is the timestamp string produced by
            ``record_vote``.
    """
    feedback_text = st.text_area(
        "(Optional) You can provide more detailed feedback below:"
    )
    if st.button("Finish", use_container_width=True):
        record_feedback(feedback_text, row_id)
        # Dropping every session key (user_id, fighters, messages, flags)
        # makes the top-level script re-initialise a fresh battle on rerun.
        st.session_state.clear()
        st.rerun()
@st.fragment
def chat_col(crs_id: int, color: str):
    """Render the chat panel for one CRS fighter.

    The panel:

    - shows the stored conversation;
    - accepts a new user message, but only while this CRS is enabled;
    - shows a typing animation while the CRS reply is fetched, then the reply;
    - offers "Frustrated" and "Satisfied" buttons that end the conversation
      through ``end_conversation``.

    Args:
        crs_id: Fighter the panel belongs to (1 or 2).
        color: Emoji colour prefix for the header circle, e.g. ``"red"`` or
            ``"large_blue"``.
    """
    messages_key = f"messages_{crs_id}"
    fighter_key = f"crs{crs_id}"
    enabled_key = f"crs{crs_id}_enabled"

    with st.container(border=True):
        st.write(f":{color}_circle: CRS {crs_id}")
        transcript = st.container(height=350, border=False)
        history = st.session_state[messages_key]
        for entry in history:
            transcript.chat_message(entry["role"]).write(entry["message"])

        # The enabled flag only changes inside button callbacks, which run
        # before the script, so reading it once is enough for this run.
        is_locked = not st.session_state[enabled_key]

        prompt = st.chat_input(
            f"Send a message to CRS {crs_id}",
            key=f"prompt_crs{crs_id}",
            disabled=is_locked,
        )
        if prompt:
            transcript.chat_message("user").write(prompt)
            history.append({"role": "user", "message": prompt})

            # The placeholder holds the typing animation until the reply
            # replaces it.
            reply_slot = transcript.chat_message("assistant").empty()
            with reply_slot, st_lottie_spinner(
                TYPING_PLACEHOLDER_JSON, height=40, width=40
            ):
                reply = get_crs_response(st.session_state[fighter_key], prompt)
            history.append({"role": "assistant", "message": reply})
            reply_slot.write(reply)

        frustrated_col, satisfied_col = st.columns(2)
        end_buttons = (
            (
                frustrated_col,
                ":rage: Frustrated",
                f"end_frustated_crs{crs_id}",
                "frustrated",
            ),
            (
                satisfied_col,
                ":heavy_check_mark: Satisfied",
                f"end_satisfied_crs{crs_id}",
                "satisfied",
            ),
        )
        for column, label, widget_key, sentiment in end_buttons:
            clicked = column.button(
                label,
                use_container_width=True,
                key=widget_key,
                on_click=end_conversation,
                kwargs={
                    "crs": st.session_state[fighter_key],
                    "sentiment": sentiment,
                },
                disabled=is_locked,
            )
            if clicked:
                st.rerun()
st.set_page_config(page_title="CRS Arena", layout="wide")

# Enlarge and centre Streamlit spinner elements (.stSpinner).
_SPINNER_CSS = """<style>
.stSpinner {
font-size: 2em;
display: flex;
justify-content: center;
align-items: center;
}
</style>
"""
st.markdown(_SPINNER_CSS, unsafe_allow_html=True)

cache_fighters()

# First run of a session: assign a user ID and pick the two fighters.
if "user_id" not in st.session_state:
    st.session_state["user_id"] = get_unique_user_id()
    _first_fighter, _second_fighter = get_crs_fighters()
    st.session_state["crs1"] = _first_fighter
    st.session_state["crs2"] = _second_fighter
# Section 1: title and introduction.
st.title(":gun: CRS Arena")
st.write(
    "Welcome to the CRS Arena! Here you can have a conversation with two "
    "conversational recommender systems (CRSs) and vote on which one you "
    "prefer."
)

# Section 2: rules shown to the user.
st.header(":page_with_curl: Rules")
st.write(
    "* Chat with each CRS (one after the other) to get **movie recommendations** "
    "until you feel satisfied or frustrated.\n"
    "* Please be patient as some CRSs may take a few seconds to respond.\n"
    "* Try to send several messages to each CRS to get a better sense of their"
    " capabilities. Don't quit after the first message!\n"
    "* To finish chatting with a CRS, click on the button corresponding to "
    "your feeling: frustrated or satisfied.\n"
    "* Vote on which CRS you prefer or declare a tie.\n"
    "* (Optional) Provide more detailed feedback after voting.\n"
)
# Section 3: side-by-side battle.
st.header(":point_down: Side-by-Side Battle")
st.write("Let's start the battle!")

# Seed each fighter's chat history and options state on the first run.
for _fighter_id in (1, 2):
    if f"messages_{_fighter_id}" not in st.session_state:
        st.session_state[f"messages_{_fighter_id}"] = []
        st.session_state[f"state_{_fighter_id}"] = None

# Only CRS 1 starts enabled. CRS 2 and voting are unlocked by
# end_conversation as the battle progresses.
for _flag, _initial in (
    ("crs1_enabled", True),
    ("crs2_enabled", False),
    ("vote_enabled", False),
):
    if _flag not in st.session_state:
        st.session_state[_flag] = _initial

for _column, (_fighter_id, _color) in zip(
    st.columns(2), ((1, "red"), (2, "large_blue"))
):
    with _column:
        chat_col(_fighter_id, _color)
# Section 4: voting. The buttons stay disabled until both conversations end.
container = st.container()
container.subheader(":trophy: Declare the winner!", anchor="vote")

_voting_locked = not st.session_state["vote_enabled"]
_vote_buttons = (
    (":red_circle: CRS 1", "crs1_wins", st.session_state["crs1"].name),
    (":large_blue_circle: CRS 2", "crs2_wins", st.session_state["crs2"].name),
    ("Tie", "tie", "tie"),
)
for _column, (_label, _key, _choice) in zip(container.columns(3), _vote_buttons):
    _column.button(
        _label,
        use_container_width=True,
        key=_key,
        on_click=record_vote,
        kwargs={"vote": _choice},
        disabled=_voting_locked,
    )
# Footer: terms of service and contact details, rendered in order.
for _heading, _body in (
    (
        "Terms of Service",
        "By using this application, you agree to the following terms of service:\n"
        "The service is a research platform. It may produce offensive content. "
        "Please do not upload any private information in the chat. The service "
        "collects the chat data and the user's vote, which may be released under a"
        " Creative Commons Attribution (CC-BY) or a similar license.",
    ),
    (
        "Contact Information",
        "For any questions, concerns, feedback, or bug reports, please contact "
        "Nolwenn Bernard at <[email protected]>.",
    ),
):
    st.header(_heading)
    st.write(_body)