ahuang11's picture
Update app.py
13c31d1 verified
import asyncio
import os
import random
import sqlite3
from contextlib import closing

import pandas as pd
import panel as pn
from litellm import acompletion
# Load Panel's "perspective" extension; the pn.pane.Perspective results table
# created further down relies on it.
pn.extension("perspective")
# Candidate models; each chat side draws one at random at the start of a round
# (see respond()).
MODELS = [
    "mistral/mistral-tiny",
    "mistral/mistral-small",
    "mistral/mistral-medium",
    "mistral/mistral-large-latest",
]
# Labels of the four voting buttons. click_vote() identifies the clicked button
# by comparing its name against these by index: 0 = A wins, 1 = tie,
# 2 = neither (no points), 3 = B wins.
# The emoji are spelled as escapes so the labels cannot be mangled again by a
# non-UTF-8 round trip of this file (they had turned into mojibake).
VOTING_LABELS = [
    "\U0001F448 A is better",  # backhand index pointing left
    "\U0001F917 About the same",  # hugging face
    "\U0001F613 Both not good",  # face with cold sweat
    "\U0001F449 B is better",  # backhand index pointing right
]
async def respond(content, user, instance):
    """
    Stream one model reply to the latest user message into ``instance``.

    On the first message of a round (no model recorded in ``chat_models``
    under this interface's name), the function does three things:

    * trims the chat history down to its last message;
    * resets both headers to the anonymous "A"/"B" titles;
    * draws a model at random from ``MODELS`` for this side.

    The interface is disabled while the reply streams and is re-enabled in
    ``finally``, even if the completion request raises.

    Args:
        content: text of the user's latest message.
        user: sender name; unused here, accepted to mirror forward_message's
            arguments.
        instance: the ``pn.chat.ChatInterface`` whose history is sent to the
            model and into which the reply is streamed.
    """
    try:
        instance.disabled = True
        chat_label = instance.name
        model = chat_models.get(chat_label)
        if not model:
            # New round: forget earlier turns (keep only the newest message),
            # hide model identities again, and pick this side's model.
            instance.objects = instance.objects[-1:]
            header_a.object = "## Model: A"
            header_b.object = "## Model: B"
            model = chat_models[chat_label] = random.choice(MODELS)

        messages = instance.serialize()
        # The new message is normally already in the history: forward_message
        # appends it to the other side, and Panel presumably appends it to the
        # sending side before invoking the callback (confirm). Appending it
        # unconditionally sent the same prompt to the model twice, so only add
        # it when serialization did not already end with it.
        if not messages or messages[-1].get("content") != content:
            messages.append({"role": "user", "content": content})

        # An empty key field falls back to the environment variable.
        api_key = api_key_input.value or os.environ.get("MISTRAL_API_KEY")
        response = await acompletion(
            model=model, messages=messages, stream=True, max_tokens=128, api_key=api_key
        )
        message = None
        async for chunk in response:
            token = chunk.choices[0].delta["content"]
            if not token:
                continue
            # Feeding back the message returned by stream() makes later tokens
            # extend that same message instead of starting a new one.
            message = instance.stream(token, user="Assistant", message=message)
    finally:
        instance.disabled = False
async def forward_message(content, user, instance):
    """
    Mirror a message typed on one side of the arena onto the other side, then
    let both sides answer it concurrently.

    This is the ``callback`` shared by both chat interfaces, so ``instance`` is
    whichever side the user typed into.
    """
    # The side that did not receive the message directly gets a copy of it.
    mirror = chat_interface_b if instance is chat_interface_a else chat_interface_a
    mirror.append(pn.chat.ChatMessage(content, user=user))

    # Run both replies at once; A is always scheduled first, then B.
    await asyncio.gather(
        respond(content, user, chat_interface_a),
        respond(content, user, chat_interface_b),
    )
def click_vote(event):
    """
    Score the current round, reveal both models, and persist the tallies.

    Registered as the ``on_click`` handler of every voting button, so
    ``event.obj.name`` is the clicked button's label from ``VOTING_LABELS``:

    * index 0 (A is better) / index 3 (B is better): +1 for that side's model;
    * index 1 (About the same): +1 for each side's model, counted only once
      when both sides drew the same model;
    * index 2 (Both not good): no points.

    Every vote ends the round:

    * the headers reveal the model behind each side;
    * ``chat_models`` is cleared so the next message draws new models;
    * the results table is refreshed;
    * ``voting_counts`` is written to ``voting_counts.db``.

    Does nothing when either side has no model assigned yet.
    """
    model_a = chat_models.get(chat_interface_a.name)
    model_b = chat_models.get(chat_interface_b.name)
    if model_a is None or model_b is None:
        # No round in progress (or only one side assigned): nothing to score.
        return

    voting_label = event.obj.name
    if voting_label == VOTING_LABELS[0]:
        rewarded = (model_a,)
    elif voting_label == VOTING_LABELS[3]:
        rewarded = (model_b,)
    elif voting_label == VOTING_LABELS[1]:
        # dict.fromkeys dedupes while keeping A-then-B order, so a model that
        # faced itself receives a single point.
        rewarded = tuple(dict.fromkeys((model_a, model_b)))
    else:
        rewarded = ()
    for model in rewarded:
        voting_counts[model] = voting_counts.get(model, 0) + 1

    header_a.object = f"## Model: {model_a}"
    header_b.object = f"## Model: {model_b}"
    chat_models.clear()

    perspective.object = (
        pd.DataFrame(voting_counts, index=["Votes"])
        .melt(var_name="Model", value_name="Votes")
        .set_index("Model")
    )
    # sqlite3's connection context manager only commits/rolls back; it never
    # closes. closing() releases the handle instead of leaking one per vote.
    with closing(sqlite3.connect("voting_counts.db")) as conn, conn:
        pd.DataFrame(list(voting_counts.items()), columns=["Model", "Votes"]).to_sql(
            "voting_counts", conn, if_exists="replace", index=False
        )
# initialize
# Current round's model per chat side, keyed by interface name ("A"/"B").
# Filled by respond() and emptied by click_vote() when a vote ends the round.
chat_models = {}
# Load the persisted tallies as {model: votes}, creating the table on first run.
# sqlite3's connection context manager only commits; closing() also closes the
# connection so the database file handle is not left open for the app's lifetime.
with closing(sqlite3.connect("voting_counts.db")) as conn, conn:
    conn.execute(
        "CREATE TABLE IF NOT EXISTS voting_counts (Model TEXT PRIMARY KEY, Votes INTEGER)"
    )
    voting_counts = (
        pd.read_sql("SELECT * FROM voting_counts", conn)
        .set_index("Model")["Votes"]
        .to_dict()
    )
# header: optional user-supplied Mistral API key. respond() uses it when
# non-empty and otherwise falls back to the MISTRAL_API_KEY environment variable.
api_key_input = pn.widgets.PasswordInput(
    placeholder="Mistral API Key",
    # Dark input text; the previous rule had a stray ";" after its closing brace.
    stylesheets=[".bk-input {color: black;}"],
)
# main: the page is a set of tabs; the first holds the arena itself.
# tab 1: two anonymous chats side by side sharing one callback, with the
# voting buttons underneath.
chat_interface_kwargs = {
    "callback": forward_message,
    # Hide the undo/rerun/clear/stop controls and the button captions.
    "show_undo": False,
    "show_rerun": False,
    "show_clear": False,
    "show_stop": False,
    "show_button_name": False,
}

# The headers start anonymous; click_vote() later rewrites them with model names.
header_a = pn.pane.Markdown("## Model: A")
header_b = pn.pane.Markdown("## Model: B")
chat_interface_a = pn.chat.ChatInterface(
    name="A", header=header_a, **chat_interface_kwargs
)
chat_interface_b = pn.chat.ChatInterface(
    name="B", header=header_b, **chat_interface_kwargs
)

# One full-width button per label; click_vote() tells them apart by name.
button_kwargs = {"sizing_mode": "stretch_width"}
button_row = pn.Row()
for voting_label in VOTING_LABELS:
    button = pn.widgets.Button(name=voting_label, **button_kwargs)
    button.on_click(click_vote)
    button_row.append(button)

chat_tab = pn.Column(pn.Row(chat_interface_a, chat_interface_b), button_row)
tabs = pn.Tabs()
tabs.append(("Chat", chat_tab))
# tab 2: read-only leaderboard with one row per model, indexed by model name.
results_frame = pd.DataFrame(voting_counts, index=["Votes"]).melt(
    var_name="Model", value_name="Votes"
)
perspective = pn.pane.Perspective(
    results_frame.set_index("Model"),
    sizing_mode="stretch_both",
    editable=False,
)
tabs.append(("Voting Results", perspective))
# layout: the API-key field sits in the template header and the tabs fill the
# main area; the template is marked servable so `panel serve` displays it.
template = pn.template.FastListTemplate(
    title="Mistral Chat Arena",
    header=[api_key_input],
    main=[tabs],
)
template.servable()