import gradio as gr
from huggingface_hub import InferenceClient
from collections import defaultdict, Counter
import random
import threading
import time

# Models in the arena rotation; add repo ids here to include more models.
llm_list = [
    "HuggingFaceH4/zephyr-7b-beta",
    "AXCXEPT/EZO-Common-9B-gemma-2-it",
]

# How many times each model has been served, used to balance exposure.
llm_counts = defaultdict(int)

# One Inference API client per model.
clients = {llm: InferenceClient(llm) for llm in llm_list}
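
# No token is passed to InferenceClient here, so it falls back to any locally
# saved Hugging Face credentials (e.g. from `huggingface-cli login` or the
# HF_TOKEN environment variable); gated or rate-limited models will need one.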


def select_llms():
    """Pick two models, favoring the least-served ones so exposure stays balanced."""
    min_count = min(llm_counts.values()) if llm_counts else 0
    candidates = [llm for llm in llm_list if llm_counts[llm] == min_count]
    if len(candidates) < 2:
        candidates = llm_list
    selected_llms = random.sample(candidates, 2)
    for llm in selected_llms:
        llm_counts[llm] += 1
    return selected_llms
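
# Illustrative behavior: with both models at count 0, the first call samples
# the full pair and bumps each count to 1. With more models registered, the
# least-served ones are always preferred, keeping exposure roughly uniform.
# Counts live in process memory only and reset whenever the app restarts.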


def respond_llm(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    llm_client,
):
    """Stream a chat completion from the given client, yielding the text so far."""
    messages = [{"role": "system", "content": system_message}]

    # Replay the (user, assistant) turn pairs that gr.ChatInterface passes in.
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    response = ""

    # `chunk` avoids shadowing the `message` parameter, and `delta.content` is
    # an attribute (possibly None on the final chunk), not a dict entry.
    for chunk in llm_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content or ""
        response += token
        yield response
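
# Note that each yield above re-emits the whole accumulated reply, not a single
# token: gr.ChatInterface replaces the displayed assistant message on every
# yield, which is what produces the typewriter-style streaming effect.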


VOTE_FILE = "votes.txt"


def save_vote(selected_llm):
    # Guard against submissions where no model was picked in the radio group.
    if not selected_llm:
        return gr.update(visible=True, value="Please choose a model before voting.")
    with open(VOTE_FILE, "a") as f:
        f.write(f"{selected_llm}\n")
    return gr.update(visible=True, value="Thank you for voting!")
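
# The vote log is append-only, one model name per line; tallies are rebuilt
# from it on every read (see update_leaderboard below), which keeps the write
# path trivial and reasonably crash-safe.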


def update_leaderboard():
    try:
        with open(VOTE_FILE, "r") as f:
            votes = f.readlines()
        vote_counts = Counter(vote.strip() for vote in votes)
        leaderboard = sorted(vote_counts.items(), key=lambda x: x[1], reverse=True)
        leaderboard_text = "## Leaderboard\n\n"
        for llm, count in leaderboard:
            leaderboard_text += f"- {llm}: {count} votes\n"
    except FileNotFoundError:
        # Keep the heading so the component can still be found by its text.
        leaderboard_text = "## Leaderboard\n\nNo votes yet."
    return leaderboard_text


def chat_interface():
    llm1, llm2 = select_llms()
    client1 = clients[llm1]
    client2 = clients[llm2]

    # Use real generator functions rather than lambdas: gr.ChatInterface
    # detects streaming with inspect.isgeneratorfunction, and a lambda that
    # merely *returns* a generator would not be streamed.
    def respond1(message, history, system_message, max_tokens, temperature, top_p):
        yield from respond_llm(message, history, system_message, max_tokens, temperature, top_p, client1)

    def respond2(message, history, system_message, max_tokens, temperature, top_p):
        yield from respond_llm(message, history, system_message, max_tokens, temperature, top_p, client2)

    with gr.Blocks() as demo:
        gr.Markdown("## LLM Comparison Arena")

        with gr.Row():
            gr.Markdown(f"### LLM1: {llm1}")
            gr.Markdown(f"### LLM2: {llm2}")

        with gr.Row():
            with gr.Column():
                chat1 = gr.ChatInterface(
                    respond1,
                    additional_inputs=[
                        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
                        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
                        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
                    ],
                )
            with gr.Column():
                chat2 = gr.ChatInterface(
                    respond2,
                    additional_inputs=[
                        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
                        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
                        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
                    ],
                )

        with gr.Row():
            vote = gr.Radio([llm1, llm2], label="Which response was better?")
            submit = gr.Button("Vote")
            result = gr.Textbox(label="", visible=False)

        submit.click(save_vote, inputs=vote, outputs=result)

        leaderboard = gr.Markdown(update_leaderboard())

    return demo
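
# Note: select_llms() runs once, when the UI is built, so every visitor sees
# the same model pair until the process restarts. Per-session pairing would
# require picking the models inside a session-scoped event (e.g. demo.load).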


def refresh_leaderboard(leaderboard_component):
    # Caveat: mutating .value from a thread does not push updates to connected
    # browsers; Gradio only sends component values through event handlers (or a
    # polling `every=` value), so this refresh may not be visible to clients.
    while True:
        leaderboard_text = update_leaderboard()
        leaderboard_component.value = leaderboard_text
        time.sleep(60)
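
# A possible alternative (an untested sketch, assuming Gradio 4.x): give the
# Markdown a callable value and let Gradio poll it on a timer, instead of the
# refresh_leaderboard thread, e.g. inside chat_interface():
#
#     leaderboard = gr.Markdown(update_leaderboard, every=60)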


if __name__ == "__main__":
    demo = chat_interface()

    # demo.blocks maps component ids to components, so iterate over the values
    # (iterating the dict itself would yield ids, not components).
    leaderboard_component = None
    for component in demo.blocks.values():
        if isinstance(component, gr.Markdown) and component.value and "Leaderboard" in component.value:
            leaderboard_component = component
            break

    if leaderboard_component:
        threading.Thread(target=refresh_leaderboard, args=(leaderboard_component,), daemon=True).start()

    demo.launch()