import argparse
from collections import defaultdict
import datetime
import json
import os
import time
import uuid

import gradio as gr
import numpy as np
import requests

from fastchat.conversation import get_default_conv_template, SeparatorStyle
from fastchat.constants import LOGDIR
from fastchat.utils import (
    build_logger,
    server_error_msg,
    violates_moderation,
    moderation_msg,
)
from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_web_server import (
    http_bot,
    set_global_vars,
    get_window_url_params,
    get_conv_log_filename,
    block_css,
    build_single_model_ui,
    no_change_btn,
    enable_btn,
    disable_btn,
    get_model_list,
    load_demo_single,
)
from fastchat.serve.inference import compute_skip_echo_len


logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

num_models = 2


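# Initial state for the side-by-side tab: the left dropdown gets the first model,
# and the right one is sampled from the rest with higher weight on earlier entries.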
def load_demo_side_by_side(url_params):
    states = (None,) * num_models

    model_left = models[0]
    if len(models) > 1:
        weights = ([8, 4, 2, 1] + [1] * 32)[:len(models) - 1]
        weights = weights / np.sum(weights)
        model_right = np.random.choice(models[1:], p=weights)
    else:
        model_right = model_left

    dropdown_updates = (
        gr.Dropdown.update(model_left, visible=True),
        gr.Dropdown.update(model_right, visible=True),
    )

    return (
        states
        + dropdown_updates
        + (gr.Chatbot.update(visible=True),) * num_models
        + (
            gr.Textbox.update(visible=True),
            gr.Box.update(visible=True),
            gr.Row.update(visible=True),
            gr.Row.update(visible=True),
            gr.Accordion.update(visible=True),
        )
    )


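# Top-level load handler: pick the initial tab from the URL params, then populate
# both the single-model and side-by-side UIs.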
def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    selected = 0
    if "arena" in url_params or "compare" in url_params:
        selected = 1
    single_updates = load_demo_single(url_params)
    side_by_side_updates = load_demo_side_by_side(url_params)
    return (gr.Tabs.update(selected=selected),) + single_updates + side_by_side_updates


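# Append one vote record (vote type, model names, conversation states, client IP)
# to the conversation log file.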
def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


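# The three handlers below log the corresponding vote, clear the textbox, and
# disable the vote buttons until the next turn.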
def leftvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"leftvote. ip: {request.client.host}")
    vote_last_response(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 3


def rightvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"rightvote. ip: {request.client.host}")
    vote_last_response(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 3


def tievote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    logger.info(f"tievote. ip: {request.client.host}")
    vote_last_response(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    return ("",) + (disable_btn,) * 3


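# Drop the last assistant reply on both sides so http_bot_all regenerates it.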
def regenerate(state0, state1, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    states = [state0, state1]
    for i in range(num_models):
        states[i].messages[-1][-1] = None
        states[i].skip_next = False
    return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 5


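# Reset both conversation states and chatbots and clear the textbox.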
def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    return [None] * num_models + [None] * num_models + [""] + [disable_btn] * 5


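# Log a share event; the screenshot itself is taken client-side (see share_js below).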
def share_click(state0, state1, model_selector0, model_selector1,
                request: gr.Request):
    logger.info(f"share. ip: {request.client.host}")
    if state0 is not None and state1 is not None:
        vote_last_response(
            [state0, state1], "share", [model_selector0, model_selector1], request
        )


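# Append the user message to both conversations, initializing each one from the
# default "vicuna" conversation template on first use.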
def add_text(state0, state1, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    states = [state0, state1]

    for i in range(num_models):
        if states[i] is None:
            states[i] = get_default_conv_template("vicuna").copy()

    if len(text) <= 0:
        for i in range(num_models):
            states[i].skip_next = True
        return (
            states
            + [x.to_gradio_chatbot() for x in states]
            + [""]
            + [
                no_change_btn,
            ]
            * 5
        )

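    # Optionally screen the prompt with the moderation check before dispatching it
    # to the models; flagged prompts are skipped and a canned message is shown.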
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
            for i in range(num_models):
                states[i].skip_next = True
            return (
                states
                + [x.to_gradio_chatbot() for x in states]
                + [moderation_msg]
                + [
                    no_change_btn,
                ]
                * 5
            )

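    # Hard cut-off on the prompt length.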
    text = text[:1536]
    for i in range(num_models):
        states[i].append_message(states[i].roles[0], text)
        states[i].append_message(states[i].roles[1], None)
        states[i].skip_next = False

    return (
        states
        + [x.to_gradio_chatbot() for x in states]
        + [""]
        + [
            disable_btn,
        ]
        * 5
    )


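# Drive both single-model streaming generators (http_bot) at once and yield merged
# UI updates until both are exhausted.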
def http_bot_all(
    state0,
    state1,
    model_selector0,
    model_selector1,
    temperature,
    max_new_tokens,
    request: gr.Request,
):
    logger.info(f"http_bot_all. ip: {request.client.host}")
    states = [state0, state1]
    model_selector = [model_selector0, model_selector1]
    gen = []
    for i in range(num_models):
        gen.append(
            http_bot(states[i], model_selector[i], temperature, max_new_tokens, request)
        )

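    # Advance both generators in lock-step, yielding whatever each side has produced so far.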
    chatbots = [None] * num_models
    while True:
        stop = True
        for i in range(num_models):
            try:
                ret = next(gen[i])
                states[i], chatbots[i] = ret[0], ret[1]
                buttons = ret[2:]
                stop = False
            except StopIteration:
                pass
        yield states + chatbots + list(buttons)
        if stop:
            break

    for i in range(10):
        if i % 2 == 0:
            yield states + chatbots + [disable_btn] * 3 + list(buttons)[3:]
        else:
            yield states + chatbots + list(buttons)
        time.sleep(0.2)


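# Build the side-by-side (Chatbot Arena) tab: two model dropdowns, two chatbots,
# vote/regenerate/clear/share buttons, and the sampling parameter sliders.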
def build_side_by_side_ui():
    notice_markdown = """
# ⚔️ Chatbot Arena ⚔️
- Chat with state-of-the-art open models **side-by-side** and vote for which one is better!
- [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/h6kCZb72G7)

### Terms of use
By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. **The service collects user dialogue data for future research.**
The demo works better on desktop devices with a wide screen.

### Choose two models to chat with
| | |
| ---- | ---- |
| [Vicuna](https://vicuna.lmsys.org): a chat assistant fine-tuned from LLaMA on user-shared conversations by LMSYS. | [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/): a dialogue model for academic research by BAIR |
| [OpenAssistant (oasst)](https://open-assistant.io/): a chat-based assistant for everyone by LAION. | [Dolly](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm): an instruction-tuned open large language model by Databricks. |
| [ChatGLM](https://chatglm.cn/blog): an open bilingual dialogue language model by Tsinghua University | [StableLM](https://github.com/stability-AI/stableLM/): Stability AI language models. |
| [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html): a model fine-tuned from LLaMA on instruction-following demonstrations by Stanford. | [LLaMA](https://arxiv.org/abs/2302.13971): open and efficient foundation language models by Meta. |
"""

    learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""

    states = [gr.State() for _ in range(num_models)]
    model_selectors = [None] * num_models
    chatbots = [None] * num_models

    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Box(elem_id="share-region"):
        with gr.Row():
            for i in range(num_models):
                with gr.Column():
                    model_selectors[i] = gr.Dropdown(
                        choices=models,
                        value=models[i] if len(models) > i else "",
                        interactive=True,
                        show_label=False,
                    ).style(container=False)

        with gr.Row():
            for i in range(num_models):
                label = "Left" if i == 0 else "Right"
                with gr.Column():
                    chatbots[i] = grChatbot(
                        label=label, elem_id=f"chatbot{i}", visible=False
                    ).style(height=550)

    with gr.Box() as button_row:
        with gr.Row():
            leftvote_btn = gr.Button(value="👈 Left is better", interactive=False)
            tie_btn = gr.Button(value="🤝 Tie", interactive=False)
            rightvote_btn = gr.Button(value="👉 Right is better", interactive=False)

    with gr.Row():
        with gr.Column(scale=20):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press ENTER",
                visible=False,
            ).style(container=False)
        with gr.Column(scale=1, min_width=50):
            send_btn = gr.Button(value="Send", visible=False)

    with gr.Row() as button_row2:
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
        share_btn = gr.Button(value="📷 Share")

    with gr.Accordion("Parameters", open=False, visible=True) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(learn_more_markdown)

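    # Register event listeners.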
    btn_list = [leftvote_btn, rightvote_btn, tie_btn, regenerate_btn, clear_btn]
    leftvote_btn.click(
        leftvote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn],
    )
    rightvote_btn.click(
        rightvote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn],
    )
    tie_btn.click(
        tievote_last_response,
        states + model_selectors,
        [textbox, leftvote_btn, rightvote_btn, tie_btn],
    )
    regenerate_btn.click(
        regenerate, states, states + chatbots + [textbox] + btn_list
    ).then(
        http_bot_all,
        states + model_selectors + [temperature, max_output_tokens],
        states + chatbots + btn_list,
    )
    clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)

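    # Client-side JS: screenshot the '#share-region' element with html2canvas and
    # download it as a PNG.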
    share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
    share_btn.click(share_click, states + model_selectors, [], _js=share_js)

    for i in range(num_models):
        model_selectors[i].change(
            clear_history, None, states + chatbots + [textbox] + btn_list
        )

    textbox.submit(
        add_text, states + [textbox], states + chatbots + [textbox] + btn_list
    ).then(
        http_bot_all,
        states + model_selectors + [temperature, max_output_tokens],
        states + chatbots + btn_list,
    )
    send_btn.click(
        add_text, states + [textbox], states + chatbots + [textbox] + btn_list
    ).then(
        http_bot_all,
        states + model_selectors + [temperature, max_output_tokens],
        states + chatbots + btn_list,
    )

    return (
        states,
        model_selectors,
        chatbots,
        textbox,
        send_btn,
        button_row,
        button_row2,
        parameter_row,
    )


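# Assemble the full demo: a "Single Model" tab and a "Chatbot Arena" tab, plus the
# load handler that reads the URL parameters.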
def build_demo():
    with gr.Blocks(
        title="Chat with Open Large Language Models",
        theme=gr.themes.Base(),
        css=block_css,
    ) as demo:
        with gr.Tabs() as tabs:
            with gr.Tab("Single Model", id=0):
                (
                    a_state,
                    a_model_selector,
                    a_chatbot,
                    a_textbox,
                    a_send_btn,
                    a_button_row,
                    a_parameter_row,
                ) = build_single_model_ui()
                a_list = [
                    a_state,
                    a_model_selector,
                    a_chatbot,
                    a_textbox,
                    a_send_btn,
                    a_button_row,
                    a_parameter_row,
                ]

            with gr.Tab("Chatbot Arena", id=1):
                (
                    b_states,
                    b_model_selectors,
                    b_chatbots,
                    b_textbox,
                    b_send_btn,
                    b_button_row,
                    b_button_row2,
                    b_parameter_row,
                ) = build_side_by_side_ui()
                b_list = (
                    b_states
                    + b_model_selectors
                    + b_chatbots
                    + [
                        b_textbox,
                        b_send_btn,
                        b_button_row,
                        b_button_row2,
                        b_parameter_row,
                    ]
                )

        url_params = gr.JSON(visible=False)

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [tabs] + a_list + b_list,
                _js=get_window_url_params,
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo


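# Entry point: parse CLI args, fetch the model list from the controller, and launch
# the Gradio demo. Example launch (assuming this module lives at
# fastchat.serve.gradio_web_server_multi and a controller is already running):
#   python3 -m fastchat.serve.gradio_web_server_multi --controller-url http://localhost:21001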
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", action="store_true")
    parser.add_argument(
        "--moderate", action="store_true", help="Enable content moderation"
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list(args.controller_url)
    set_global_vars(args.controller_url, args.moderate, models)

    logger.info(args)
    demo = build_demo()
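    # Queue requests so generator-based streaming responses work and concurrency stays bounded.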
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(
        server_name=args.host, server_port=args.port, share=args.share, max_threads=200
    )