# Page-header residue captured with this file: "Spaces: Running".
import gradio as gr | |
from functools import lru_cache | |
import random | |
import requests | |
import logging | |
import arena_config | |
import plotly.graph_objects as go | |
from typing import Dict | |
from leaderboard import get_current_leaderboard, update_leaderboard | |
# Module-level logger for this file; the basic logging configuration is set
# to the ERROR level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.ERROR)
def get_available_models():
    """Return the model identifiers listed in ``arena_config.APPROVED_MODELS``.

    The identifier is taken from position 0 of every approved entry; the list
    comes straight from the configuration module.
    """
    model_ids = []
    for entry in arena_config.APPROVED_MODELS:
        model_ids.append(entry[0])
    return model_ids
def call_ollama_api(model, prompt):
    """Send ``prompt`` to ``model`` through the chat-completions endpoint.

    The request is POSTed to ``{arena_config.API_URL}/v1/chat/completions``
    with ``arena_config.HEADERS`` and a 100 second timeout. No caching is
    performed: every call issues a new HTTP request.

    Args:
        model: Model identifier placed in the request payload.
        prompt: User message sent as the single chat message.

    Returns:
        The ``content`` of the first choice's message on success, otherwise
        the fixed string ``"Error: Unable to get response from the model."``.
        Failures are logged and never raised to the caller.
    """
    error_message = "Error: Unable to get response from the model."
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    try:
        response = requests.post(
            f"{arena_config.API_URL}/v1/chat/completions",
            headers=arena_config.HEADERS,
            json=payload,
            timeout=100,
        )
        response.raise_for_status()
        data = response.json()
        return data["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        logger.error("Error calling Ollama API for model %s: %s", model, e)
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # A 2xx reply whose body is not the expected
        # {"choices": [{"message": {"content": ...}}]} shape (or is not JSON
        # at all) must not crash the UI handler that called us.
        logger.error(
            "Unexpected response from Ollama API for model %s: %r", model, e
        )
    return error_message
def generate_responses(prompt):
    """Ask two randomly chosen approved models to answer ``prompt``.

    Returns:
        A 4-tuple ``(response_a, response_b, model_a, model_b)``. When fewer
        than two models are available both responses are the same error
        string and both model entries are ``None``.
    """
    candidates = get_available_models()
    if len(candidates) < 2:
        return (
            "Error: Not enough models available",
            "Error: Not enough models available",
            None,
            None,
        )

    model_a, model_b = random.sample(candidates, 2)
    # Query in a fixed order: model A first, then model B.
    answer_a, answer_b = [
        call_ollama_api(chosen, prompt) for chosen in (model_a, model_b)
    ]
    return answer_a, answer_b, model_a, model_b
def battle_arena(prompt):
    """Run one battle and produce the updates for the arena widgets.

    Two model responses are generated for ``prompt``, randomly assigned to the
    left and right sides, and each side is given a display nickname.

    Args:
        prompt: The user prompt forwarded to ``generate_responses``.

    Returns:
        An 8-tuple in the order expected by the submit button's outputs:
        left messages, right messages, left model id, right model id,
        left chatbot update, right chatbot update, left vote button update,
        right vote button update.
    """
    response_a, response_b, model_a, model_b = generate_responses(prompt)

    # Draw two different nicknames whenever the pool allows it; independent
    # draws could label both sides (and both vote buttons) identically.
    nicknames = arena_config.model_nicknames
    if len(nicknames) >= 2:
        nickname_a, nickname_b = random.sample(nicknames, 2)
    else:
        nickname_a = random.choice(nicknames)
        nickname_b = random.choice(nicknames)

    # generate_responses returns None for both models when it could not pick
    # two; a vote would be rejected in that case, so keep the buttons disabled.
    can_vote = model_a is not None and model_b is not None

    # Messages are formatted for gr.Chatbot.
    left = ([{"role": "assistant", "content": response_a}], model_a)
    right = ([{"role": "assistant", "content": response_b}], model_b)
    if not random.choice([True, False]):
        # Randomize which model is shown on which side.
        left, right = right, left

    (left_messages, left_model), (right_messages, right_model) = left, right
    return (
        left_messages, right_messages, left_model, right_model,
        gr.update(label=nickname_a, value=left_messages),
        gr.update(label=nickname_b, value=right_messages),
        gr.update(interactive=can_vote, value=f"Vote for {nickname_a}"),
        gr.update(interactive=can_vote, value=f"Vote for {nickname_b}"),
    )
def record_vote(prompt, left_response, right_response, left_model, right_model, choice):
    """Record a vote for one side and refresh the result widgets.

    Args:
        prompt: The prompt text (accepted for the event signature; not used).
        left_response: Content of the left chatbot.
        right_response: Content of the right chatbot.
        left_model: Model id shown on the left.
        right_model: Model id shown on the right.
        choice: ``"Left is better"`` selects the left model as winner; any
            other value selects the right model.

    Returns:
        A 6-tuple of updates: result message, leaderboard HTML, left vote
        button, right vote button, model-names row, leaderboard chart.
    """
    # Nothing to vote on until both responses and both models are present.
    if not (left_response and right_response and left_model and right_model):
        return (
            "Please generate responses before voting.",
            gr.update(),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(visible=False),
            gr.update(),
        )

    if choice == "Left is better":
        winner, loser = left_model, right_model
    else:
        winner, loser = right_model, left_model

    # Persist the outcome; the refreshed standings are re-read below.
    update_leaderboard(winner, loser)

    result_message = (
        "\n"
        "🎉 Vote recorded! You're awesome! 🌟\n"
        f"🔵 In the left corner: {get_human_readable_name(left_model)}\n"
        f"🔴 In the right corner: {get_human_readable_name(right_model)}\n"
        f"🏆 And the champion you picked is... {get_human_readable_name(winner)}! 🥇\n"
    )

    return (
        gr.update(value=result_message, visible=True),  # Show the result
        get_leaderboard(),                              # Refresh leaderboard
        gr.update(interactive=False),                   # Disable left vote button
        gr.update(interactive=False),                   # Disable right vote button
        gr.update(visible=True),                        # Reveal model names
        get_leaderboard_chart(),                        # Refresh chart
    )
# Static opening of the leaderboard markup: styles plus the header row.
_LEADERBOARD_TABLE_HEAD = """
<style>
.leaderboard-table {
width: 100%;
border-collapse: collapse;
font-family: Arial, sans-serif;
}
.leaderboard-table th, .leaderboard-table td {
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}
.leaderboard-table th {
background-color: rgba(255, 255, 255, 0.1);
font-weight: bold;
}
.rank-column {
width: 60px;
text-align: center;
}
.opponent-details {
font-size: 0.9em;
color: #888;
}
</style>
<table class='leaderboard-table'>
<tr>
<th class='rank-column'>Rank</th>
<th>Model</th>
<th>Score</th>
<th>Wins</th>
<th>Losses</th>
<th>Win Rate</th>
<th>Total Battles</th>
<th>Top Rival</th>
<th>Toughest Opponent</th>
</tr>
"""

# One table row per model; filled in with str.format.
_LEADERBOARD_ROW = """
<tr>
<td class='rank-column'>{rank_display}</td>
<td>{model_name}</td>
<td>{score:.4f}</td>
<td>{wins}</td>
<td>{losses}</td>
<td>{win_rate:.2f}%</td>
<td>{total_battles}</td>
<td class='opponent-details'>{top_rival_name} (W: {top_rival_wins})</td>
<td class='opponent-details'>{toughest_opponent_name} (L: {toughest_opponent_losses})</td>
</tr>
"""

# Medal shown instead of the number for the first three ranks.
_RANK_MEDALS = {1: "🥇", 2: "🥈", 3: "🥉"}


def get_leaderboard():
    """Render the current standings as an HTML table.

    Each model's score is ``win_rate * (1 - 1 / (total_battles + 1))``, which
    gives more weight to models with more battles; models without battles
    score 0. Rows are ordered by score, then by total battles, descending.

    Returns:
        A string of HTML (style block plus table).
    """
    battle_results = get_current_leaderboard()

    # Scores are kept in a local list rather than written back into the
    # dict returned by get_current_leaderboard().
    ranked = []
    for model, results in battle_results.items():
        wins, losses = results["wins"], results["losses"]
        total_battles = wins + losses
        if total_battles > 0:
            score = (wins / total_battles) * (1 - 1 / (total_battles + 1))
        else:
            score = 0
        ranked.append((score, total_battles, model, results))
    ranked.sort(key=lambda row: (row[0], row[1]), reverse=True)

    parts = [_LEADERBOARD_TABLE_HEAD]
    for index, (score, total_battles, model, results) in enumerate(ranked, start=1):
        wins, losses = results["wins"], results["losses"]
        win_rate = (wins / total_battles * 100) if total_battles > 0 else 0

        # Tolerate entries that carry no per-opponent breakdown.
        opponents = results.get("opponents") or {}

        # Top rival: the opponent this model has beaten most often.
        top_rival = max(
            opponents.items(),
            key=lambda item: item[1]["wins"],
            default=(None, {"wins": 0}),
        )
        # Toughest opponent: the opponent this model has lost to most often.
        toughest_opponent = max(
            opponents.items(),
            key=lambda item: item[1]["losses"],
            default=(None, {"losses": 0}),
        )

        parts.append(_LEADERBOARD_ROW.format(
            rank_display=_RANK_MEDALS.get(index, f"{index}"),
            model_name=get_human_readable_name(model),
            score=score,
            wins=wins,
            losses=losses,
            win_rate=win_rate,
            total_battles=total_battles,
            top_rival_name=(
                get_human_readable_name(top_rival[0]) if top_rival[0] else "N/A"
            ),
            top_rival_wins=top_rival[1]["wins"],
            toughest_opponent_name=(
                get_human_readable_name(toughest_opponent[0])
                if toughest_opponent[0] else "N/A"
            ),
            toughest_opponent_losses=toughest_opponent[1]["losses"],
        ))

    parts.append("</table>")
    return "".join(parts)
def get_leaderboard_chart():
    """Build the Plotly figure summarising model performance.

    Wins and losses are drawn as stacked bars per model and the score
    (``win_rate * (1 - 1 / (total_battles + 1))``, 0 without battles) as a
    line on a secondary y-axis. Models are ordered by score, then by total
    battles, descending.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    battle_results = get_current_leaderboard()

    # Scores are computed locally so the dict returned by
    # get_current_leaderboard() is left untouched.
    scored = []
    for model, results in battle_results.items():
        wins, losses = results["wins"], results["losses"]
        total_battles = wins + losses
        if total_battles > 0:
            score = (wins / total_battles) * (1 - 1 / (total_battles + 1))
        else:
            score = 0
        scored.append((model, wins, losses, score))
    scored.sort(key=lambda row: (row[3], row[1] + row[2]), reverse=True)

    models = [get_human_readable_name(model) for model, _, _, _ in scored]
    wins = [row[1] for row in scored]
    losses = [row[2] for row in scored]
    scores = [row[3] for row in scored]

    fig = go.Figure()

    # Stacked bars: wins and losses.
    fig.add_trace(go.Bar(
        x=models,
        y=wins,
        name='Wins',
        marker_color='#22577a',
    ))
    fig.add_trace(go.Bar(
        x=models,
        y=losses,
        name='Losses',
        marker_color='#38a3a5',
    ))

    # Line: score, plotted against the secondary axis.
    fig.add_trace(go.Scatter(
        x=models,
        y=scores,
        name='Score',
        yaxis='y2',
        line=dict(color='#ff7f0e', width=2),
    ))

    # Layout: stacked bars, secondary y-axis on the right, horizontal legend
    # above the plot.
    fig.update_layout(
        title='Model Performance',
        xaxis_title='Models',
        yaxis_title='Number of Battles',
        yaxis2=dict(
            title='Score',
            overlaying='y',
            side='right',
        ),
        barmode='stack',
        height=800,
        width=1450,
        autosize=True,
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=1.02,
            xanchor='right',
            x=1,
        ),
    )
    return fig
def new_battle():
    """Reset the arena widgets for a fresh battle.

    Returns:
        An 11-tuple in the order of the "New Battle" button's outputs:
        prompt input, left chatbot, right chatbot, left model, right model,
        left vote button, right vote button, result, leaderboard,
        model-names row, leaderboard chart.
    """
    # Draw two different nicknames whenever the pool allows it; independent
    # draws could give both sides (and both vote buttons) the same label.
    nicknames = arena_config.model_nicknames
    if len(nicknames) >= 2:
        nickname_a, nickname_b = random.sample(nicknames, 2)
    else:
        nickname_a = random.choice(nicknames)
        nickname_b = random.choice(nicknames)

    return (
        "",                                      # Reset prompt_input
        gr.update(value=[], label=nickname_a),   # Reset left Chatbot
        gr.update(value=[], label=nickname_b),   # Reset right Chatbot
        None,                                    # Clear left model
        None,                                    # Clear right model
        gr.update(interactive=False, value=f"Vote for {nickname_a}"),
        gr.update(interactive=False, value=f"Vote for {nickname_b}"),
        gr.update(value="", visible=False),      # Hide and clear result
        gr.update(),                             # Leaderboard unchanged
        gr.update(visible=False),                # Hide model names
        gr.update(),                             # Chart unchanged
    )
def get_human_readable_name(model_name: str) -> str:
    """Map a model identifier to its display name.

    The lookup table is built from ``arena_config.APPROVED_MODELS``; an
    identifier that is not in the table is returned unchanged.
    """
    display_names = dict(arena_config.APPROVED_MODELS)
    if model_name in display_names:
        return display_names[model_name]
    return model_name
def random_prompt():
    """Pick one entry at random from ``arena_config.example_prompts``."""
    prompt_pool = arena_config.example_prompts
    return random.choice(prompt_pool)
# Initial labels for the two chatbots. Two different nicknames are drawn when
# the pool allows it, so the sides (and their vote buttons) are
# distinguishable.
_nickname_pool = arena_config.model_nicknames
if len(_nickname_pool) >= 2:
    _left_nickname, _right_nickname = random.sample(_nickname_pool, 2)
else:
    _left_nickname = random.choice(_nickname_pool)
    _right_nickname = random.choice(_nickname_pool)

# Build the Gradio UI: three tabs (arena, leaderboard, chart) plus the event
# wiring between them.
with gr.Blocks(css="""
#dice-button {
min-height: 90px;
font-size: 35px;
}
""") as demo:
    gr.Markdown(arena_config.ARENA_NAME)
    gr.Markdown(arena_config.ARENA_DESCRIPTION)

    # Battle Arena Tab
    with gr.Tab("Battle Arena"):
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                scale=20
            )
            random_prompt_btn = gr.Button("🎲", scale=1, elem_id="dice-button")

        gr.Markdown("<br>")

        # The dice button fills the prompt box with a random example prompt.
        random_prompt_btn.click(
            random_prompt,
            outputs=prompt_input
        )

        submit_btn = gr.Button("Generate Responses", variant="primary")

        with gr.Row():
            left_output = gr.Chatbot(label=_left_nickname, type="messages")
            right_output = gr.Chatbot(label=_right_nickname, type="messages")

        # Vote buttons start disabled until responses exist.
        with gr.Row():
            left_vote_btn = gr.Button(f"Vote for {left_output.label}", interactive=False)
            right_vote_btn = gr.Button(f"Vote for {right_output.label}", interactive=False)

        result = gr.Textbox(label="Result", interactive=False, visible=False)

        # Model identities stay hidden until a vote has been cast.
        with gr.Row(visible=False) as model_names_row:
            left_model = gr.Textbox(label="🔵 Left Model", interactive=False)
            right_model = gr.Textbox(label="🔴 Right Model", interactive=False)

        new_battle_btn = gr.Button("New Battle")

    # Leaderboard Tab
    with gr.Tab("Leaderboard"):
        leaderboard = gr.HTML(label="Leaderboard")

    # Performance Chart Tab
    with gr.Tab("Performance Chart"):
        leaderboard_chart = gr.Plot(label="Model Performance Chart")

    # Define interactions. The output lists below must stay in the same order
    # as the tuples returned by the corresponding handler functions.
    submit_btn.click(
        battle_arena,
        inputs=prompt_input,
        outputs=[left_output, right_output, left_model, right_model,
                 left_output, right_output, left_vote_btn, right_vote_btn]
    )

    left_vote_btn.click(
        lambda *args: record_vote(*args, "Left is better"),
        inputs=[prompt_input, left_output, right_output, left_model, right_model],
        outputs=[result, leaderboard, left_vote_btn,
                 right_vote_btn, model_names_row, leaderboard_chart]
    )

    right_vote_btn.click(
        lambda *args: record_vote(*args, "Right is better"),
        inputs=[prompt_input, left_output, right_output, left_model, right_model],
        outputs=[result, leaderboard, left_vote_btn,
                 right_vote_btn, model_names_row, leaderboard_chart]
    )

    new_battle_btn.click(
        new_battle,
        outputs=[prompt_input, left_output, right_output, left_model,
                 right_model, left_vote_btn, right_vote_btn,
                 result, leaderboard, model_names_row, leaderboard_chart]
    )

    # Update leaderboard and chart on launch
    demo.load(get_leaderboard, outputs=leaderboard)
    demo.load(get_leaderboard_chart, outputs=leaderboard_chart)

if __name__ == "__main__":
    demo.launch()