File size: 12,733 Bytes
b599481
 
 
 
 
 
 
6a3ff8c
0c71ade
 
 
 
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c84464e
dbb98e1
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c84464e
 
b599481
 
 
 
 
 
 
 
 
 
 
 
dbb98e1
0c71ade
 
dbb98e1
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a26986
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c38f0c
b599481
 
 
 
 
 
 
 
 
1c38f0c
 
b599481
 
 
 
 
 
 
 
 
 
1c38f0c
 
 
 
b599481
dbb98e1
d19bfb1
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbb98e1
31b6d03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c84464e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31b6d03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b599481
d0ff58f
b599481
0c71ade
 
 
 
 
 
 
 
 
 
 
 
 
810de64
 
 
b599481
 
 
 
 
 
 
 
 
 
 
 
 
d0ff58f
 
b599481
 
 
97c2180
b599481
ec1c3b2
0c71ade
 
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31b6d03
b599481
 
 
31b6d03
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c71ade
a977fb9
 
0c71ade
 
b599481
d19bfb1
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
"""Streamlit app for side-by-side battle of two CRSs.

The goal of this application is for a user to have a conversation with two
conversational recommender systems (CRSs) and vote on which one they prefer.
All the conversations are recorded and saved for future analysis. When the user
arrives on the app, they are assigned a unique ID. Then, two CRSs are chosen
for the battle depending on the number of conversations already recorded (i.e.,
the CRSs with the least number of conversations are chosen). The user interacts
with the CRSs one after the other. The user can then vote on which CRS they
prefer; upon voting, a pop-up appears giving the user the option to provide
more detailed feedback. Once the user clicks "Finish" in the pop-up, the
feedback is logged and the app restarts for a new battle.

The app is composed of six sections:
1. Title/Introduction
2. Rules
3. Side-by-Side Battle
4. Feedback (voting)
5. Terms of Service
6. Contact Information
"""

import asyncio
import json
import logging
import os
from copy import deepcopy
from datetime import datetime
from typing import Dict, List

import streamlit as st
from battle_manager import (
    CONVERSATION_COUNTS,
    cache_fighters,
    get_crs_fighters,
    get_unique_user_id,
)
from crs_fighter import CRSFighter
from streamlit_lottie import st_lottie_spinner
from utils import upload_conversation_logs_to_hf, upload_feedback_to_gsheet

from src.model.crb_crs.recommender import *

# A chat message is a dictionary with the keys "role" and "message". The
# metadata entry appended to logged conversations uses "role" and
# "sentiment" instead.
Message = Dict[str, str]

logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# This module logs at INFO (above the WARNING root default) so that votes,
# feedback and finished conversations are recorded.
logger.setLevel(logging.INFO)

# Directory where each finished conversation is written as a JSON file.
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)

# Lottie animation displayed while waiting for a CRS response. The file is
# read inside a ``with`` block so its handle is closed right away.
with open(
    "asset/Typing_animation.json", "r", encoding="utf-8"
) as typing_animation_file:
    TYPING_PLACEHOLDER_JSON = json.load(typing_animation_file)


# Callbacks
def record_vote(vote: str) -> None:
    """Store the user's vote for the current battle and open the feedback pop-up.

    The vote row is identified by the current timestamp (as a string). The
    same identifier is handed to the feedback dialog so that any optional
    feedback can be linked back to this vote.

    Args:
        vote: Name of the winning CRS model, or "tie".
    """
    session = st.session_state
    row_id = str(datetime.now())
    vote_row = {
        "id": row_id,
        "user_id": session["user_id"],
        "crs1": session["crs1"].name,
        "crs2": session["crs2"].name,
        "vote": vote,
    }
    logger.info(
        "Vote: %s, %s, %s, %s, %s",
        row_id,
        vote_row["user_id"],
        vote_row["crs1"],
        vote_row["crs2"],
        vote,
    )
    asyncio.run(upload_feedback_to_gsheet(vote_row, worksheet="votes"))
    feedback_dialog(row_id=row_id)


def record_feedback(feedback: str, row_id: str) -> None:
    """Record the user's feedback in the "feedback" worksheet.

    Restarting the app after feedback is handled by the caller
    (``feedback_dialog``), not by this function.

    Args:
        feedback: User's free-text feedback (may be an empty string).
        row_id: ID of the vote row this feedback refers to. ``record_vote``
            generates it as ``str(datetime.now())``.
    """
    logger.info("Feedback: %s, %s", row_id, feedback)
    asyncio.run(
        upload_feedback_to_gsheet(
            {"id": row_id, "feedback": feedback}, worksheet="feedback"
        )
    )


def end_conversation(crs: CRSFighter, sentiment: str) -> None:
    """Ends the conversation with given CRS model.

    Writes the conversation to a JSON log file, followed by a metadata entry
    holding the user's sentiment. It then uploads that file to the Hugging
    Face Hub, increments the CRS's conversation count, and enables either
    the second CRS or the voting section.

    This function is the ``on_click`` callback of the "Frustrated" and
    "Satisfied" buttons. Streamlit reruns the script after a callback on its
    own. Calling ``st.rerun()`` inside a callback is a no-op that only
    displays a warning, so no explicit rerun is done here.

    Args:
        crs: CRS model whose conversation ends.
        sentiment: User's sentiment ("frustrated" or "satisfied").
    """
    fighter_id = crs.fighter_id
    user_id = st.session_state["user_id"]

    # Work on a copy so the metadata entry only goes to the log file. The
    # displayed chat history expects every entry to have a "message" key.
    messages: List[Message] = deepcopy(
        st.session_state[f"messages_{fighter_id}"]
    )
    messages.append({"role": "metadata", "sentiment": sentiment})
    logger.info("User %s ended conversation with %s.", user_id, crs.name)

    log_file_name = f"{user_id}_{crs.name}.json"
    log_file_path = os.path.join(CONVERSATION_LOG_DIR, log_file_name)
    # NOTE(review): suppose a second conversation happened with the same
    # user ID and CRS. Append mode would then leave two concatenated JSON
    # arrays in this file, which is not valid JSON. Confirm this cannot
    # happen.
    with open(log_file_path, "a") as f:
        json.dump(messages, f)

    # Update the conversation count
    CONVERSATION_COUNTS[crs.name] += 1

    # Save the conversation log to the Hugging Face Hub. asyncio.run blocks
    # until the upload coroutine has finished.
    asyncio.run(
        upload_conversation_logs_to_hf(
            log_file_path, f"conversation_logs/{log_file_name}"
        )
    )

    if fighter_id == 1:
        # First CRS finished: lock its chat and unlock the second CRS.
        st.session_state["crs1_enabled"] = False
        st.session_state["crs2_enabled"] = True
    elif fighter_id == 2:
        # Second CRS finished: lock its chat and unlock the voting section.
        st.session_state["crs2_enabled"] = False
        st.session_state["vote_enabled"] = True


def get_crs_response(crs: CRSFighter, message: str) -> str:
    """Gets the CRS response for the given message.

    The CRS receives three inputs:
    - the user's message;
    - the conversation history stored in the session state;
    - the options state saved from its previous turn.

    The new options state returned by the CRS is stored back in the session
    state for the next turn.

    Args:
        crs: CRS model.
        message: User's message.

    Returns:
        CRS response.
    """
    fighter_id = crs.fighter_id
    # NOTE(review): chat_col appends the user's ``message`` to this history
    # before calling this function. The history therefore already ends with
    # the message that is also passed as ``input_message``. Confirm whether
    # ``crs.reply`` expects the history with or without it.
    response, state = crs.reply(
        input_message=message,
        history=st.session_state[f"messages_{fighter_id}"],
        options_state=st.session_state.get(f"state_{fighter_id}", []),
    )
    st.session_state[f"state_{fighter_id}"] = state
    return response


@st.dialog("Your vote has been submitted! Thank you!")
def feedback_dialog(row_id: str) -> None:
    """Pop-up dialog to provide feedback after voting.

    Feedback is optional. When the user clicks "Finish", three things happen:
    - the (possibly empty) feedback is recorded against the vote's row ID;
    - the whole session state is cleared;
    - the app reruns, which starts a new battle with a new user ID.

    Args:
        row_id: Row ID of the vote, i.e. the timestamp string created in
            ``record_vote``.
    """
    feedback_text = st.text_area(
        "(Optional) You can provide more detailed feedback below:"
    )
    if st.button("Finish", use_container_width=True):
        record_feedback(feedback_text, row_id)
        # Restart the app: without a "user_id" in the session state, the
        # battle setup assigns a new user and new CRSs on the next run.
        st.session_state.clear()
        st.rerun()


@st.fragment
def chat_col(crs_id: int, color: str):
    """Render the chat panel of one CRS.

    The panel has three parts:
    - a replay of the stored conversation;
    - a chat input that forwards new user messages to the CRS;
    - the "Frustrated" and "Satisfied" buttons that end the conversation.

    The input and the buttons are disabled unless ``crs{crs_id}_enabled`` is
    true in the session state.

    Args:
        crs_id: CRS model ID (either 1 or 2).
        color: Color of the CRS model (red or large_blue).
    """
    history_key = f"messages_{crs_id}"
    crs_key = f"crs{crs_id}"
    enabled_key = f"crs{crs_id}_enabled"

    with st.container(border=True):
        st.write(f":{color}_circle: CRS {crs_id}")
        chat_box = st.container(height=350, border=False)
        # Replay the conversation recorded so far.
        for entry in st.session_state[history_key]:
            chat_box.chat_message(entry["role"]).write(entry["message"])

        prompt = st.chat_input(
            f"Send a message to CRS {crs_id}",
            key=f"prompt_crs{crs_id}",
            disabled=not st.session_state[enabled_key],
        )
        if prompt:
            # Echo the user's message and store it in the chat history.
            chat_box.chat_message("user").write(prompt)
            st.session_state[history_key].append(
                {"role": "user", "message": prompt}
            )

            reply_slot = chat_box.chat_message("assistant").empty()
            with reply_slot:
                # Typing animation shown while waiting for the CRS.
                with st_lottie_spinner(
                    TYPING_PLACEHOLDER_JSON, height=40, width=40
                ):
                    reply = get_crs_response(
                        st.session_state[crs_key], prompt
                    )
                    st.session_state[history_key].append(
                        {"role": "assistant", "message": reply}
                    )
                reply_slot.write(reply)

        # (label, widget key, sentiment) for the two end-of-chat buttons.
        # The "frustated" spelling in the key is kept so widget keys stay
        # stable.
        end_buttons = (
            (":rage: Frustrated", f"end_frustated_crs{crs_id}", "frustrated"),
            (
                ":heavy_check_mark: Satisfied",
                f"end_satisfied_crs{crs_id}",
                "satisfied",
            ),
        )
        for column, (label, key, sentiment) in zip(
            st.columns(2), end_buttons
        ):
            clicked = column.button(
                label,
                use_container_width=True,
                key=key,
                on_click=end_conversation,
                kwargs={
                    "crs": st.session_state[crs_key],
                    "sentiment": sentiment,
                },
                disabled=not st.session_state[enabled_key],
            )
            if clicked:
                # Rerun the whole app, not just this fragment, so the other
                # chat panel and the vote buttons see the updated state.
                st.rerun()


# Streamlit app
# NOTE(review): the page config is not set in this file. The app uses
# st.switch_page, so it is presumably a multipage app whose entry script
# sets the config. Confirm.
# st.set_page_config(page_title="CRS Arena", layout="wide")

# Inject CSS that enlarges the spinner text and centers the spinner.
st.markdown(
    """<style>
            .stSpinner {
                font-size: 2em;
                 display: flex;
                 justify-content: center;
                 align-items: center;
            }
        </style>
    """,
    unsafe_allow_html=True,
)
# Cache some CRS fighters at startup (delegated to battle_manager).
cache_fighters()

# Battle setup. When the session has no user ID yet (first visit, or after
# the session state was cleared), assign a new user ID and pick the two CRSs
# for this battle.
if "user_id" not in st.session_state:
    st.session_state["user_id"] = get_unique_user_id()
    st.session_state["crs1"], st.session_state["crs2"] = get_crs_fighters()


# Introduction, with a shortcut to the leaderboard page.
st.title(":gun: CRS Arena")
st.write(
    "Welcome to the CRS Arena! Here you can have a conversation with two "
    "conversational recommender systems (CRSs) and vote on which one you "
    "prefer."
)
if st.button(":trophy: Leaderboard"):
    st.switch_page("leaderboard.py")

# Rules shown to the user before the battle.
st.header(":page_with_curl: Rules")
st.write(
    "* Chat with each CRS (one after the other) to get **movie recommendations** "
    "up until you feel satisfied or frustrated.\n"
    "* Please be patient as some CRSs may take a few seconds to respond.\n"
    "* Try to send several messages to each CRS to get a better sense of their"
    " capabilities. Don't quit after the first message!\n"
    "* To finish chatting with a CRS, click on the button corresponding to "
    "your feeling: frustrated or satisfied.\n"
    "* Vote on which CRS you prefer or declare a tie.\n"
    "* (Optional) Provide more detailed feedback after voting.\n"
)

# Side-by-Side Battle
st.header(":point_down: Side-by-Side Battle")
st.write("Let's start the battle!")

# On first run, give each CRS an empty chat history and no options state.
for _fighter_id in (1, 2):
    if f"messages_{_fighter_id}" not in st.session_state:
        st.session_state[f"messages_{_fighter_id}"] = []
        st.session_state[f"state_{_fighter_id}"] = None

# Only CRS 1 starts enabled. CRS 2 and the vote are unlocked later, when
# the previous conversation ends.
for _flag, _initial in (
    ("crs1_enabled", True),
    ("crs2_enabled", False),
    ("vote_enabled", False),
):
    if _flag not in st.session_state:
        st.session_state[_flag] = _initial

# Two chat panels side by side: CRS 1 (red) on the left, CRS 2 (blue) on the
# right.
col_crs1, col_crs2 = st.columns(2)
for _column, _crs_id, _color in (
    (col_crs1, 1, "red"),
    (col_crs2, 2, "large_blue"),
):
    with _column:
        chat_col(_crs_id, _color)

# Feedback section: the three vote buttons stay disabled until the
# conversation with the second CRS has ended.
container = st.container()
container.subheader(":trophy: Declare the winner!", anchor="vote")
# (label, widget key, vote value passed to record_vote)
_vote_options = (
    (":red_circle: CRS 1", "crs1_wins", st.session_state["crs1"].name),
    (
        ":large_blue_circle: CRS 2",
        "crs2_wins",
        st.session_state["crs2"].name,
    ),
    ("Tie", "tie", "tie"),
)
for _vote_col, (_label, _key, _vote) in zip(
    container.columns(3), _vote_options
):
    _vote_col.button(
        _label,
        use_container_width=True,
        key=_key,
        on_click=record_vote,
        kwargs={"vote": _vote},
        disabled=not st.session_state["vote_enabled"],
    )

# Terms of service
st.header("Terms of Service")
st.write(
    "By using this application, you agree to the following terms of service:\n"
    "The service is a research platform. It may produce offensive content. "
    "Please do not upload any private information in the chat. The service "
    "collects the chat data and the user's vote, which may be released under a"
    " Creative Commons Attribution (CC-BY) or a similar license."
)

# Contact information
st.header("Contact Information")
st.write(
    "For any questions, concerns, feedback, or bug reports, please contact "
    "Nolwenn Bernard at <[email protected]>."
)