|
""" |
|
Multimodal Chatbot Arena (side-by-side) tab. |
|
Users chat with two chosen models. |
|
""" |
|
|
|
import json |
|
import os |
|
import time |
|
|
|
import gradio as gr |
|
import numpy as np |
|
|
|
from src.constants import ( |
|
TEXT_MODERATION_MSG, |
|
IMAGE_MODERATION_MSG, |
|
MODERATION_MSG, |
|
CONVERSATION_LIMIT_MSG, |
|
SLOW_MODEL_MSG, |
|
INPUT_CHAR_LEN_LIMIT, |
|
CONVERSATION_TURN_LIMIT, |
|
) |
|
from src.model.model_adapter import get_conversation_template |
|
from src.serve.gradio_block_arena_named import ( |
|
flash_buttons, |
|
share_click, |
|
bot_response_multi, |
|
) |
|
from src.serve.gradio_block_arena_vision import ( |
|
get_vqa_sample, |
|
set_invisible_image, |
|
set_visible_image, |
|
add_image, |
|
moderate_input, |
|
) |
|
from src.serve.gradio_web_server import ( |
|
State, |
|
bot_response, |
|
get_conv_log_filename, |
|
no_change_btn, |
|
enable_btn, |
|
disable_btn, |
|
invisible_btn, |
|
acknowledgment_md, |
|
get_ip, |
|
get_model_description_md, |
|
_prepare_text_with_image, |
|
) |
|
from src.serve.remote_logger import get_remote_logger |
|
from src.utils import ( |
|
build_logger, |
|
moderation_filter, |
|
image_moderation_filter, |
|
) |
|
|
|
|
|
# Logger for this tab; the two arguments are presumably the logger name and
# its log file name — see build_logger in src.utils to confirm.
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

# Number of models shown side by side (labelled "Model A" and "Model B").
num_sides: int = 2

# NOTE(review): not referenced anywhere in the visible part of this file;
# add_text moderates input through moderate_input instead. Confirm it is
# unused elsewhere before removing.
enable_moderation: bool = False
|
|
|
|
|
def clear_history_example(request: gr.Request):
    """Reset both sides after a new example/image is loaded into the input.

    Unlike ``clear_history`` this leaves the textbox alone, so the freshly
    loaded example stays in place.

    Returns:
        Updates laid out as ``states + chatbots + buttons``: two ``None``
        states, two ``None`` chatbots, the four vote buttons hidden and the
        remaining two buttons disabled.
    """
    logger.info(f"clear_history_example (named). ip: {get_ip(request)}")
    cleared_states = [None] * num_sides
    cleared_chatbots = [None] * num_sides
    hidden_vote_buttons = [invisible_btn] * 4
    disabled_action_buttons = [disable_btn] * 2
    return (
        cleared_states
        + cleared_chatbots
        + hidden_vote_buttons
        + disabled_action_buttons
    )
|
|
|
|
|
def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
    """Persist a single vote event.

    One JSON line is appended to the conversation log file chosen from the
    first state's ``is_vision``/``has_csam_image`` flags, and the same
    record is then sent to the remote logger.

    Args:
        states: The per-side conversation states; each is serialized via
            its ``dict()`` method.
        vote_type: Label stored under ``"type"`` (e.g. ``"leftvote"``).
        model_selectors: Model names for each side, stored under ``"models"``.
        request: The Gradio request, used to record the client IP.
    """
    first_state = states[0]
    log_path = get_conv_log_filename(first_state.is_vision, first_state.has_csam_image)
    with open(log_path, "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": list(model_selectors),
            "states": [side_state.dict() for side_state in states],
            "ip": get_ip(request),
        }
        log_file.write(f"{json.dumps(record)}\n")
    get_remote_logger().log(record)
|
|
|
|
|
def leftvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a vote for side A ("leftvote") and lock the vote buttons.

    Returns:
        ``None`` for the textbox followed by four ``disable_btn`` updates
        for the vote buttons.
    """
    logger.info(f"leftvote (named). ip: {get_ip(request)}")
    pair_states = [state0, state1]
    pair_models = [model_selector0, model_selector1]
    vote_last_response(pair_states, "leftvote", pair_models, request)
    return (None, disable_btn, disable_btn, disable_btn, disable_btn)
|
|
|
|
|
def rightvote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a vote for side B ("rightvote") and lock the vote buttons.

    Returns:
        ``None`` for the textbox followed by four ``disable_btn`` updates
        for the vote buttons.
    """
    logger.info(f"rightvote (named). ip: {get_ip(request)}")
    pair_states = [state0, state1]
    pair_models = [model_selector0, model_selector1]
    vote_last_response(pair_states, "rightvote", pair_models, request)
    return (None, disable_btn, disable_btn, disable_btn, disable_btn)
|
|
|
|
|
def tievote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record a tie ("tievote") between the two sides and lock the vote buttons.

    Returns:
        ``None`` for the textbox followed by four ``disable_btn`` updates
        for the vote buttons.
    """
    logger.info(f"tievote (named). ip: {get_ip(request)}")
    pair_states = [state0, state1]
    pair_models = [model_selector0, model_selector1]
    vote_last_response(pair_states, "tievote", pair_models, request)
    return (None, disable_btn, disable_btn, disable_btn, disable_btn)
|
|
|
|
|
def bothbad_vote_last_response(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Record that both answers were bad ("bothbad_vote") and lock the vote buttons.

    Returns:
        ``None`` for the textbox followed by four ``disable_btn`` updates
        for the vote buttons.
    """
    logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    pair_states = [state0, state1]
    pair_models = [model_selector0, model_selector1]
    vote_last_response(pair_states, "bothbad_vote", pair_models, request)
    return (None, disable_btn, disable_btn, disable_btn, disable_btn)
|
|
|
|
|
def regenerate(state0, state1, request: gr.Request):
    """Prepare both sides to regenerate their last reply, if both allow it.

    When both states report ``regen_support``, each conversation's last
    message is reset with ``update_last_message(None)`` and all six buttons
    are disabled. Otherwise both states are flagged with ``skip_next`` and
    the buttons are left unchanged.

    Returns:
        ``states + chatbots + [textbox] + buttons``: the two states, their
        chatbot renderings, ``None`` for the textbox, and six button updates.
    """
    logger.info(f"regenerate (named). ip: {get_ip(request)}")
    states = [state0, state1]

    if not (state0.regen_support and state1.regen_support):
        for side_state in states:
            side_state.skip_next = True
        button_update = no_change_btn
    else:
        for side in range(num_sides):
            states[side].conv.update_last_message(None)
        button_update = disable_btn

    chat_views = [side_state.to_gradio_chatbot() for side_state in states]
    return states + chat_views + [None] + [button_update] * 6
|
|
|
|
|
def clear_history(request: gr.Request):
    """Start a fresh round: drop both conversations and empty the input.

    Returns:
        Updates laid out as ``states + chatbots + [textbox] + buttons``:
        two ``None`` states, two ``None`` chatbots, ``None`` for the textbox,
        the four vote buttons hidden and the remaining two buttons disabled.
    """
    logger.info(f"clear_history (named). ip: {get_ip(request)}")
    cleared_states = [None] * num_sides
    cleared_chatbots = [None] * num_sides
    cleared_textbox = [None]
    hidden_vote_buttons = [invisible_btn] * 4
    disabled_action_buttons = [disable_btn] * 2
    return (
        cleared_states
        + cleared_chatbots
        + cleared_textbox
        + hidden_vote_buttons
        + disabled_action_buttons
    )
|
|
|
|
|
def _skip_both_sides(states, textbox_value):
    """Flag both sides to skip this turn and build the matching UI updates.

    Returns ``states + chatbots + [textbox] + buttons`` with the textbox set
    to ``textbox_value`` and all six buttons left unchanged.
    """
    for side_state in states:
        # Presumably checked by bot_response_multi so no model call is made.
        side_state.skip_next = True
    return (
        states
        + [x.to_gradio_chatbot() for x in states]
        + [textbox_value]
        + [no_change_btn] * 6
    )


def add_text(
    state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request
):
    """Append the user's multimodal message to both side-by-side conversations.

    Args:
        state0, state1: Per-side ``State`` objects, or ``None`` before the
            first turn (a vision ``State`` is then created for that side).
        model_selector0, model_selector1: Model names selected for side A
            and side B; used when a new ``State`` must be created.
        chat_input: Value of the ``MultimodalTextbox``; a dict whose
            ``"text"`` and ``"files"`` entries are read here.
        request: The Gradio request, used for client-IP logging.

    Returns:
        A list laid out as ``states + chatbots + [textbox] + buttons``.
        The turn is skipped on both sides (buttons unchanged) when:
        - the text is empty;
        - the conversation turn limit is reached (textbox shows
          ``CONVERSATION_LIMIT_MSG``);
        - the image is flagged by moderation (textbox shows
          ``IMAGE_MODERATION_MSG``).
        Otherwise the message is appended to both sides, the textbox is
        cleared, and all six buttons are disabled.
    """
    text, images = chat_input["text"], chat_input["files"]
    ip = get_ip(request)
    logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
    states = [state0, state1]
    model_selectors = [model_selector0, model_selector1]

    # Create a vision-enabled state for any side that has not chatted yet.
    for i in range(num_sides):
        if states[i] is None:
            states[i] = State(model_selectors[i], is_vision=True)

    if len(text) <= 0:
        return _skip_both_sides(states, None)

    model_list = [states[i].model_name for i in range(num_sides)]
    # Moderation context: the tail of each side's prompt plus the new message.
    # (Previously both halves were read from side A, so side B was never
    # included in the moderation context.)
    all_conv_text_left = states[0].conv.get_prompt()
    all_conv_text_right = states[1].conv.get_prompt()
    all_conv_text = (
        all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
    )

    text, image_flagged, csam_flag = moderate_input(
        text, all_conv_text, model_list, images, ip
    )

    # Messages are always appended to both sides together (see the loop
    # below), so side A's message count stands for both.
    conv = states[0].conv
    if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
        return _skip_both_sides(states, {"text": CONVERSATION_LIMIT_MSG})

    if image_flagged:
        logger.info(f"image flagged. ip: {ip}. text: {text}")
        return _skip_both_sides(states, {"text": IMAGE_MODERATION_MSG})

    text = text[:INPUT_CHAR_LEN_LIMIT]
    for i in range(num_sides):
        post_processed_text = _prepare_text_with_image(
            states[i], text, images, csam_flag=csam_flag
        )
        logger.info(f"msg={post_processed_text}")
        states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
        # Empty assistant slot; presumably filled in by the response step.
        states[i].conv.append_message(states[i].conv.roles[1], None)
        states[i].skip_next = False

    return (
        states
        + [x.to_gradio_chatbot() for x in states]
        + [None]
        + [disable_btn] * 6
    )
|
|
|
|
|
def build_side_by_side_vision_ui_named(models, random_questions=None):
    """Build the side-by-side vision arena tab and register its listeners.

    Presumably called inside an enclosing ``gr.Blocks`` context, since it
    only creates components and wires events.

    Args:
        models: Model ids offered in both dropdowns. ``models[0]`` and
            ``models[1]`` (when present) are pre-selected for sides A and B.
            Known ids get a friendlier label in the dropdown.
        random_questions: Optional path to a JSON file of sample questions.
            When given, the file is loaded and a "Random Example" button is
            added.

    Returns:
        ``states + model_selectors``: the two per-side ``gr.State``
        components followed by the two model ``gr.Dropdown`` components.
    """
    notice_markdown = """
# ⚔️ Vision Arena ⚔️ : Benchmarking FIRE-LLaVA VS. LLaVA-NeXT

## 📜 Rules
- Chat with any two models side-by-side and vote!
- You can continue chatting for multiple rounds.
- Click "Clear history" to start a new round.
- You can only chat with <span style='color: #DE3163; font-weight: bold'>one image per conversation</span>. You can upload images less than 15MB.

**❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.**

## 🤖 Choose two models to compare
"""

    states = [gr.State() for _ in range(num_sides)]
    model_selectors = [None] * num_sides
    chatbots = [None] * num_sides

    # Display labels for known model ids; any other id is shown as-is.
    model_display_names = {
        "llava-fire": "FIRE-LLaVA",
        "llava-original": "LLaVA-NeXT-LLaMA-3-8B",
    }
    # (label, value) choices shared by both dropdowns; built once, not per side.
    model_choices = [(model_display_names.get(m, m), m) for m in models]

    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Row():
        # Image preview column; starts hidden and is toggled by the
        # set_visible_image / set_invisible_image listeners below.
        with gr.Column(scale=2, visible=False) as image_column:
            imagebox = gr.Image(
                type="pil",
                show_label=False,
                interactive=False,
            )

        with gr.Column(scale=5):
            # This group is what the Share button screenshots (see share_js).
            with gr.Group(elem_id="share-region-anony"):
                with gr.Accordion(
                    f"🔍 Expand to see the descriptions of {len(models)} models",
                    open=False,
                ):
                    model_description_md = get_model_description_md(models)
                    gr.Markdown(
                        model_description_md, elem_id="model_description_markdown"
                    )

                with gr.Row():
                    for i in range(num_sides):
                        with gr.Column():
                            model_selectors[i] = gr.Dropdown(
                                choices=model_choices,
                                value=models[i] if len(models) > i else "",
                                interactive=True,
                                show_label=False,
                                container=False,
                            )

                with gr.Row():
                    for i in range(num_sides):
                        label = "Model A" if i == 0 else "Model B"
                        with gr.Column():
                            chatbots[i] = gr.Chatbot(
                                label=label,
                                elem_id="chatbot",
                                height=550,
                                show_copy_button=True,
                            )

    # Vote buttons start hidden and disabled; listeners update them via btn_list.
    with gr.Row():
        leftvote_btn = gr.Button(
            value="👈 A is better", visible=False, interactive=False
        )
        rightvote_btn = gr.Button(
            value="👉 B is better", visible=False, interactive=False
        )
        tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
        bothbad_btn = gr.Button(
            value="👎 Both are bad", visible=False, interactive=False
        )

    with gr.Row():
        recommendation = gr.Textbox(
            visible=True,
            label="Teacher generated feedback:",
            show_copy_button=True,
        )

    with gr.Row():
        textbox = gr.MultimodalTextbox(
            file_types=["image"],
            show_label=False,
            placeholder="Click add or drop your image here",
            container=True,
            elem_id="input_box",
        )

    with gr.Row():
        if random_questions:
            # NOTE(review): this binds `vqa_samples` in THIS module's globals.
            # get_vqa_sample is imported from gradio_block_arena_vision; if it
            # reads a module-level `vqa_samples`, it reads that module's global,
            # not this one. Confirm the Random Example button works in this tab.
            global vqa_samples
            with open(random_questions, "r", encoding="utf-8") as f:
                vqa_samples = json.load(f)
            random_btn = gr.Button(value="🎲 Random Example", interactive=True)
        clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        share_btn = gr.Button(value="📷 Share")

    # Each example fills the multimodal input and the teacher-feedback box.
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    {
                        "files": ["assets/image_50.png"],
                        "text": "Please directly answer the question and provide the correct option letter, e.g., A, B, C, D.\nQuestion: As shown in the figure, then angle COE = ()\nChoices:\nA:30°\nB:140°\nC:50°\nD:60°",
                    },
                    "Your answer is incorrect. The question asks for the angle COE in the context of the figure provided. Consider the relationships between the angles and the lines in the figure to find the correct answer. Try again by analyzing the given diagram more carefully.",
                ],
                [
                    {
                        "files": ["assets/magnetic.png"],
                        "text": """Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: Will these magnets attract or repel each other?
Choices:
A. repel
B. attract""",
                    },
                    """You correctly identified that the letters "N" and "S" represent opposite poles of a magnet. However, your conclusion that they repel each other is incorrect. Please reconsider your answer with this information in mind.""",
                ],
                [
                    {
                        "files": ["assets/fox.png"],
                        "text": """Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: Which of the following organisms is the primary consumer in this food web?
Choices:
A. Arctic fox
B. rough-legged hawk
C. mushroom""",
                    },
                    """You correctly identified that the primary consumers are the organisms that feed directly on the producers. However, your answer is incorrect. The mushroom is not a primary consumer; it is a decomposer. Look again at the food web and identify which organisms are shown as consuming the producers directly. Try to find the correct option among the given choices.""",
                ],
                [
                    {
                        "files": ["assets/test_11407.jpg"],
                        "text": """Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: 如图,△ABC中,AD为中线,AD⊥AC,∠BAD=30°,AB=3,则AC长()
Choices:
A. 2.5
B. 2
C. 1
D. 1.5""",
                    },
                    "",
                ],
            ],
            inputs=[textbox, recommendation],
        )

    with gr.Accordion("Parameters", open=False):
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Order matters: handlers return six button updates in exactly this order.
    btn_list = [
        leftvote_btn,
        rightvote_btn,
        tie_btn,
        bothbad_btn,
        regenerate_btn,
        clear_btn,
    ]
    vote_outputs = [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn]
    leftvote_btn.click(leftvote_last_response, states + model_selectors, vote_outputs)
    rightvote_btn.click(
        rightvote_last_response, states + model_selectors, vote_outputs
    )
    tie_btn.click(tievote_last_response, states + model_selectors, vote_outputs)
    bothbad_btn.click(
        bothbad_vote_last_response, states + model_selectors, vote_outputs
    )
    regenerate_btn.click(
        regenerate, states, states + chatbots + [textbox] + btn_list
    ).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )
    clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)

    # Client-side screenshot of the share region. The selector must match the
    # elem_id of the gr.Group above ("share-region-anony"); otherwise
    # querySelector returns null and html2canvas fails.
    share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region-anony');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
    share_btn.click(share_click, states + model_selectors, [], js=share_js)

    # Switching either model starts a fresh round.
    for i in range(num_sides):
        model_selectors[i].change(
            clear_history, None, states + chatbots + [textbox] + btn_list
        ).then(set_visible_image, [textbox], [image_column])

    textbox.input(add_image, [textbox], [imagebox]).then(
        set_visible_image, [textbox], [image_column]
    ).then(clear_history_example, None, states + chatbots + btn_list)

    textbox.submit(
        add_text,
        states + model_selectors + [textbox],
        states + chatbots + [textbox] + btn_list,
    ).then(set_invisible_image, [], [image_column]).then(
        bot_response_multi,
        states + [temperature, top_p, max_output_tokens],
        states + chatbots + btn_list,
    ).then(
        flash_buttons, [], btn_list
    )

    if random_questions:
        random_btn.click(
            get_vqa_sample,
            [],
            [textbox, imagebox],
        ).then(set_visible_image, [textbox], [image_column]).then(
            clear_history_example, None, states + chatbots + btn_list
        )

    return states + model_selectors
|
|