|
"""Streamlit app for side-by-side battle of two CRSs. |
|
|
|
The goal of this application is for a user to have a conversation with two
conversational recommender systems (CRSs) and vote on which one they prefer.
All the conversations are recorded and saved for future analysis. When the user
arrives on the app, they are assigned a unique ID. Then, two CRSs are chosen
for the battle depending on the number of conversations already recorded (i.e.,
the CRSs with the fewest conversations are chosen). The user interacts with the
CRSs one after the other. The user can then vote on which CRS they prefer; upon
voting, a pop-up appears giving the user the option to provide more detailed
feedback. Once the vote is submitted, all data is logged and the app restarts
for a new battle.

The app is composed of four sections:
1. Title/Introduction
2. Rules
3. Side-by-Side Battle
4. Feedback
|
""" |
|
|
|
import asyncio |
|
import json |
|
import logging |
|
import os |
|
from copy import deepcopy |
|
from datetime import datetime |
|
from typing import Dict, List |
|
|
|
import streamlit as st |
|
from battle_manager import ( |
|
CONVERSATION_COUNTS, |
|
cache_fighters, |
|
get_crs_fighters, |
|
get_unique_user_id, |
|
) |
|
from crs_fighter import CRSFighter |
|
from streamlit_lottie import st_lottie_spinner |
|
from utils import upload_conversation_logs_to_hf, upload_feedback_to_gsheet |
|
|
|
from src.model.crb_crs.recommender import * |
|
|
|
|
|
# A single chat entry as kept in the session state and written to the
# conversation logs (e.g. {"role": "user", "message": "..."}).
Message = Dict[str, str]

logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# The root logger is configured at WARNING; this module's logger is lowered
# to INFO so that votes, feedback and ended conversations are logged.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Local directory where conversation logs are written before being uploaded.
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)

# Lottie animation displayed while waiting for a CRS response. The file is
# opened in a context manager so the handle is closed right after parsing
# instead of being left to the garbage collector.
with open(
    "asset/Typing_animation.json", "r", encoding="utf-8"
) as _typing_animation_file:
    TYPING_PLACEHOLDER_JSON = json.load(_typing_animation_file)
|
|
|
|
|
|
|
def record_vote(vote: str) -> None:
    """Store the user's vote and open the feedback dialog.

    The vote is logged and uploaded to the "votes" worksheet together with
    the user ID and the names of the two CRSs stored in the session state.
    The feedback dialog is then opened for the uploaded row.

    Args:
        vote: Vote value. The vote buttons in this file pass the name of the
          preferred CRS or "tie".
    """
    session = st.session_state
    voter_id = session["user_id"]
    fighter_1: CRSFighter = session["crs1"]
    fighter_2: CRSFighter = session["crs2"]

    # The current timestamp serves as the row ID; it is also handed to the
    # feedback dialog so optional feedback can refer to this vote.
    vote_row_id = str(datetime.now())

    name_1 = fighter_1.name
    name_2 = fighter_2.name
    logger.info(
        f"Vote: {vote_row_id}, {voter_id}, {name_1}, {name_2}, {vote}"
    )

    vote_row = {
        "id": vote_row_id,
        "user_id": voter_id,
        "crs1": name_1,
        "crs2": name_2,
        "vote": vote,
    }
    asyncio.run(upload_feedback_to_gsheet(vote_row, worksheet="votes"))

    feedback_dialog(row_id=vote_row_id)
|
|
|
|
|
def record_feedback(feedback: str, row_id: str) -> None:
    """Record the user's feedback in the database.

    The feedback is logged and uploaded to the "feedback" worksheet. This
    function does not restart the app; that is left to the caller.

    Args:
        feedback: User's feedback text.
        row_id: ID of the vote row the feedback refers to. `record_vote`
          creates it as a timestamp string.
    """
    logger.info("Feedback: %s, %s", row_id, feedback)
    asyncio.run(
        upload_feedback_to_gsheet(
            {"id": row_id, "feedback": feedback}, "feedback"
        )
    )
|
|
|
|
|
def end_conversation(crs: CRSFighter, sentiment: str) -> None:
    """Close the conversation with the given CRS and unlock the next step.

    A copy of the conversation, extended with a metadata entry holding the
    user's sentiment, is appended to a local log file and uploaded. The
    conversation counter of the CRS is incremented. Ending the first
    conversation enables the second CRS; ending the second one enables
    voting. Finally, a rerun of the app is requested.

    Args:
        crs: CRS model whose conversation ends.
        sentiment: User's sentiment (frustrated or satisfied).
    """
    # Work on a copy so the metadata entry never shows up in the chat.
    transcript: List[Message] = deepcopy(
        st.session_state[f"messages_{crs.fighter_id}"]
    )
    transcript.append({"role": "metadata", "sentiment": sentiment})

    user_id = st.session_state["user_id"]
    logger.info(f"User {user_id} ended conversation with {crs.name}.")

    log_name = f"{user_id}_{crs.name}.json"
    local_log_path = os.path.join(CONVERSATION_LOG_DIR, log_name)
    with open(local_log_path, "a") as log_file:
        json.dump(transcript, log_file)

    CONVERSATION_COUNTS[crs.name] += 1

    asyncio.run(
        upload_conversation_logs_to_hf(
            local_log_path, f"conversation_logs/{log_name}"
        )
    )

    # Session flags to set once the conversation with a given fighter ends.
    # Any other fighter ID leaves the flags untouched.
    flag_updates = {
        1: {"crs1_enabled": False, "crs2_enabled": True},
        2: {"crs2_enabled": False, "vote_enabled": True},
    }
    for flag, value in flag_updates.get(crs.fighter_id, {}).items():
        st.session_state[flag] = value

    st.rerun()
|
|
|
|
|
def get_crs_response(crs: CRSFighter, message: str) -> str:
    """Gets the CRS response for the given message.

    The request is delegated to `crs.reply`, which receives the user's
    message, the conversation history that precedes it, and the options
    state stored for this CRS in the session. The state returned by the CRS
    is stored back in the session for the next turn.

    Args:
        crs: CRS model.
        message: User's message.

    Returns:
        CRS response.
    """
    messages_key = f"messages_{crs.fighter_id}"
    state_key = f"state_{crs.fighter_id}"

    # `chat_col` appends the user's message to the session history before
    # calling this function, so the last entry is dropped to avoid sending
    # that message twice (it is already passed as `input_message`). The deep
    # copy keeps the CRS from mutating the history held in the session.
    history = deepcopy(st.session_state[messages_key])[:-1]

    response, state = crs.reply(
        input_message=message,
        history=history,
        options_state=st.session_state.get(state_key, []),
    )
    st.session_state[state_key] = state

    return response
|
|
|
|
|
@st.dialog("Your vote has been submitted! Thank you!")
def feedback_dialog(row_id: str) -> None:
    """Pop-up dialog to provide feedback after voting.

    Feedback is optional and can be used to provide more detailed information
    about the user's vote. When the user clicks "Finish", the content of the
    text area is recorded, the session state is cleared and the app is rerun.

    Args:
        row_id: Unique row ID of the vote. `record_vote` creates it as a
          timestamp string.
    """
    feedback_text = st.text_area(
        "(Optional) You can provide more detailed feedback below:"
    )
    if st.button("Finish", use_container_width=True):
        record_feedback(feedback_text, row_id)
        # Clearing the session removes "user_id", so the module-level setup
        # assigns a new user ID and a new pair of CRSs on the rerun.
        st.session_state.clear()
        st.rerun()
|
|
|
|
|
@st.fragment
def chat_col(crs_id: int, color: str):
    """Render the chat column of one CRS.

    The column shows the conversation so far, a chat input to send a new
    message, and two buttons to end the conversation as frustrated or
    satisfied. The input and the buttons are disabled unless the CRS is
    enabled in the session state.

    Args:
        crs_id: CRS model ID (either 1 or 2).
        color: Color of the CRS model (red or large_blue).
    """
    messages_key = f"messages_{crs_id}"
    fighter_key = f"crs{crs_id}"
    enabled_key = f"crs{crs_id}_enabled"

    with st.container(border=True):
        st.write(f":{color}_circle: CRS {crs_id}")

        # Replay the conversation so far.
        chat_area = st.container(height=350, border=False)
        for past_message in st.session_state[messages_key]:
            bubble = chat_area.chat_message(past_message["role"])
            bubble.write(past_message["message"])

        user_text = st.chat_input(
            f"Send a message to CRS {crs_id}",
            key=f"prompt_crs{crs_id}",
            disabled=not st.session_state[enabled_key],
        )
        if user_text:
            # Show the user's message and add it to the history.
            chat_area.chat_message("user").write(user_text)
            st.session_state[messages_key].append(
                {"role": "user", "message": user_text}
            )

            # Show a typing animation in the assistant bubble while the CRS
            # computes its response.
            reply_slot = chat_area.chat_message("assistant").empty()
            with reply_slot:
                with st_lottie_spinner(
                    TYPING_PLACEHOLDER_JSON, height=40, width=40
                ):
                    reply_text = get_crs_response(
                        st.session_state[fighter_key], user_text
                    )

            # Add the response to the history and display it in the bubble.
            st.session_state[messages_key].append(
                {"role": "assistant", "message": reply_text}
            )
            reply_slot.write(reply_text)

        # One button per way of ending the conversation:
        # (label, sentiment, widget key).
        endings = (
            (":rage: Frustrated", "frustrated", f"end_frustated_crs{crs_id}"),
            (
                ":heavy_check_mark: Satisfied",
                "satisfied",
                f"end_satisfied_crs{crs_id}",
            ),
        )
        for button_col, (label, sentiment, widget_key) in zip(
            st.columns(2), endings
        ):
            clicked = button_col.button(
                label,
                use_container_width=True,
                key=widget_key,
                on_click=end_conversation,
                kwargs={
                    "crs": st.session_state[fighter_key],
                    "sentiment": sentiment,
                },
                disabled=not st.session_state[enabled_key],
            )
            if clicked:
                st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page-wide CSS: enlarge the spinner and center its content.
st.markdown(
    """<style>
.stSpinner {
font-size: 2em;
display: flex;
justify-content: center;
align-items: center;
}
</style>
""",
    unsafe_allow_html=True,
)

cache_fighters()

# A session without a user ID is a new battle: assign a user ID and pick the
# two CRSs that will fight.
if "user_id" not in st.session_state:
    st.session_state["user_id"] = get_unique_user_id()
    (
        st.session_state["crs1"],
        st.session_state["crs2"],
    ) = get_crs_fighters()
|
|
|
|
|
|
|
# Section 1: title and introduction.
st.title(":gun: CRS Arena")
st.write(
    "Welcome to the CRS Arena! Here you can have a conversation with two "
    "conversational recommender systems (CRSs) and vote on which one you "
    "prefer."
)
if st.button(":trophy: Leaderboard"):
    st.switch_page("leaderboard.py")

# Section 2: rules of the battle, rendered as a Markdown bullet list.
st.header(":page_with_curl: Rules")
st.write(
    "* Chat with each CRS (one after the other) to get "
    "**movie recommendations** up until you feel satisfied or frustrated.\n"
    "* Please be patient as some CRSs may take a few seconds to respond.\n"
    "* Try to send several messages to each CRS to get a better sense of their"
    " capabilities. Don't quit after the first message!\n"
    "* To finish chatting with a CRS, click on the button corresponding to "
    "your feeling: frustrated or satisfied.\n"
    "* Vote on which CRS you prefer or declare a tie.\n"
    "* (Optional) Provide more detailed feedback after voting.\n"
)
|
|
|
|
|
# Section 3: side-by-side battle.
st.header(":point_down: Side-by-Side Battle")
st.write("Let's start the battle!")

# Each CRS starts with an empty conversation and no options state.
for _fighter_id in (1, 2):
    if f"messages_{_fighter_id}" not in st.session_state:
        st.session_state[f"messages_{_fighter_id}"] = []
        st.session_state[f"state_{_fighter_id}"] = None

# Only the first CRS is enabled at the start; the second CRS and the vote
# are unlocked as conversations end.
for _flag, _initial_value in (
    ("crs1_enabled", True),
    ("crs2_enabled", False),
    ("vote_enabled", False),
):
    if _flag not in st.session_state:
        st.session_state[_flag] = _initial_value

# One chat column per CRS, side by side.
col_crs1, col_crs2 = st.columns(2)
with col_crs1:
    chat_col(1, "red")
with col_crs2:
    chat_col(2, "large_blue")
|
|
|
|
|
# Section 4: vote. The buttons stay disabled until voting is enabled in the
# session state.
container = st.container()
container.subheader(":trophy: Declare the winner!", anchor="vote")
container_col1, container_col2, container_col3 = container.columns(3)

# (column, label, widget key, session key of the voted CRS or None for a tie)
for _column, _label, _widget_key, _fighter_key in (
    (container_col1, ":red_circle: CRS 1", "crs1_wins", "crs1"),
    (container_col2, ":large_blue_circle: CRS 2", "crs2_wins", "crs2"),
    (container_col3, "Tie", "tie", None),
):
    if _fighter_key is None:
        _vote = "tie"
    else:
        _vote = st.session_state[_fighter_key].name
    _column.button(
        _label,
        use_container_width=True,
        key=_widget_key,
        on_click=record_vote,
        kwargs={"vote": _vote},
        disabled=not st.session_state["vote_enabled"],
    )
|
|
|
|
|
# Legal notice displayed below the battle.
st.header("Terms of Service")
st.write(
    "By using this application, you agree to the following terms of service:\n"
    "The service is a research platform. "
    "It may produce offensive content. "
    "Please do not upload any private information in the chat. "
    "The service collects the chat data and the user's vote, "
    "which may be released under a Creative Commons Attribution (CC-BY) "
    "or a similar license."
)

# Contact details for questions and bug reports.
st.header("Contact Information")
st.write(
    "For any questions, concerns, feedback, or bug reports, "
    "please contact Nolwenn Bernard at <[email protected]>."
)
|
|