Spaces:
Running
Running
import pandas as pd | |
import streamlit as st | |
import time | |
from collections import defaultdict | |
from streamlit_image_select import image_select | |
import requests | |
import os | |
st.set_page_config(layout="wide")
description = """
# Anime Leaderboard
Text to Image (Anime/Illustration) Generation Leaderboard.
This leaderboard is just for fun and does not reflect the actual performance of the models.
## How to Use
- Select the image that best reflects the given prompt.
- Your selections contribute to the global leaderboard.
- View your personal leaderboard after making at least 30 selections.
## Data
- Data Source: [nyanko7/image-samples](https://huggingface.co/datasets/nyanko7/image-samples)
- Calling for submissions: [open issue](https://huggingface.co/spaces/nyanko7/text-to-anime-arena/discussions/new) or contact me to submit your model
- Warning: Some images may contain NSFW content.
"""
# Streamlit re-executes this script on every interaction; only seed keys the
# session does not have yet so values from earlier runs survive.
if 'selections' not in st.session_state:
    st.session_state['selections'] = []  # selection dicts appended by arena()
if 'selection_count' not in st.session_state:
    st.session_state['selection_count'] = 0
if 'last_pair' not in st.session_state:
    st.session_state['last_pair'] = None
if 'user_id' not in st.session_state:
    st.session_state['user_id'] = None  # taken from the server's submit response
if 'last_state' not in st.session_state:
    st.session_state['last_state'] = None  # markdown summary of the latest vote
st.sidebar.markdown(description)
# Base URL of the backend, read from the W_SERVER environment variable.
# A trailing slash is stripped so f"{SERVER_URL}/next_pair" never contains "//".
SERVER_URL = (os.getenv("W_SERVER") or "").rstrip("/")
if not SERVER_URL:
    # Without this every request fails with a generic "Failed to ..." message.
    st.sidebar.error("W_SERVER is not set; the backend server cannot be reached.")
def get_next_pair(timeout=10):
    """Fetch the next image pair to compare from the backend.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON body of ``GET {SERVER_URL}/next_pair`` when the server
        answers 200, otherwise ``None`` (after showing an error in the UI).
    """
    try:
        # Without a timeout a stalled server would block this script run forever.
        response = requests.get(f"{SERVER_URL}/next_pair", timeout=timeout)
        if response.status_code == 200:
            return response.json()
        print(response)
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a 200 response whose body is not valid JSON.
        print(e)
    st.error("Failed to fetch next pair from server")
    return None
# Fetch a pair for this session. Also retry while the stored pair is None
# (a previous fetch failed); checking only for the key's presence would leave
# the session stuck without a pair after a single server hiccup.
if st.session_state.get("pair") is None:
    st.session_state["pair"] = get_next_pair()
def submit_selection(selection_result, timeout=10):
    """Send one vote to the backend and remember the user id it returns.

    Args:
        selection_result: Dict sent as the JSON body (``arena`` builds it with
            ``model_a``, ``model_b``, ``winner`` and ``time`` keys).
        timeout: Seconds to wait for the server before giving up.

    Failures are reported in the UI with ``st.error``; nothing is raised.
    """
    headers = {}
    user_id = st.session_state.get('user_id')
    if user_id:
        # requests only accepts str/bytes header values; the id comes from the
        # server's JSON and may not be a string.
        headers['User-ID'] = str(user_id)
    try:
        response = requests.post(
            f"{SERVER_URL}/submit_selection",
            json=selection_result,
            headers=headers,
            timeout=timeout,
        )
        if response.status_code == 200:
            response_data = response.json()
            if isinstance(response_data, dict) and 'user_id' in response_data:
                st.session_state['user_id'] = response_data['user_id']
            return
        print(response)
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a 200 response whose body is not valid JSON.
        print(e)
    st.error("Failed to submit selection to server")
def get_leaderboard_data(timeout=10):
    """Fetch the global leaderboard from the backend.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON body of ``GET {SERVER_URL}/leaderboard`` when the
        server answers 200, otherwise ``None`` (after showing an error).
    """
    try:
        response = requests.get(f"{SERVER_URL}/leaderboard", timeout=timeout)
        if response.status_code == 200:
            return response.json()
        print(response)
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a 200 response whose body is not valid JSON.
        print(e)
    st.error("Failed to fetch leaderboard data from server")
    return None
import io | |
from PIL import Image | |
def open_image_from_url(image_url, timeout=30):
    """Download an image and return it as a fully decoded PIL image.

    Args:
        image_url: URL of the image to download.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: Connection failure or timeout.
        OSError: The body is not a readable image (PIL's
            ``UnidentifiedImageError`` is an ``OSError``).
    """
    # The whole body is read into memory anyway, so streaming is not needed;
    # the context manager releases the connection even if an error is raised.
    with requests.get(image_url, timeout=timeout) as response:
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
    # Image.open is lazy; decode now so corrupt data fails here, inside the
    # caller's error handling, instead of later during rendering.
    image.load()
    return image
def _advance_to_next_pair():
    """Replace the current pair with a freshly fetched one and rerun the app."""
    st.session_state["pair"] = get_next_pair()
    # App-scope rerun: arena() is registered as a plain page rather than an
    # @st.fragment, and st.rerun(scope="fragment") is only valid inside one.
    st.rerun()


def arena():
    """Render one head-to-head comparison and record the user's vote.

    Reads the current pair from ``st.session_state["pair"]`` (a dict with
    ``image1``, ``image2``, ``model_a``, ``model_b`` and ``prompt`` keys),
    shows both images and, once one is picked, appends the result to the
    session log, submits it to the server and moves on to a new pair.
    """
    pair = st.session_state.get("pair")
    if pair is None:
        # Fetching a pair failed; offer a retry instead of crashing on None[...].
        st.warning("No image pair is available right now.")
        if st.button("Retry"):
            _advance_to_next_pair()
        return
    image_url1, model_a = pair["image1"], pair["model_a"]
    image_url2, model_b = pair["image2"], pair["model_b"]
    prompt = pair["prompt"]
    st.markdown("**Which image best reflects this prompt?**")
    st.info(
        f"""
Prompt: {prompt}
""",
        icon="⏳",
    )
    # Download both images up front; a dead link or undecodable file should
    # not crash the page but let the user skip to another pair.
    try:
        images = [open_image_from_url(image_url1), open_image_from_url(image_url2)]
    except (requests.RequestException, OSError) as e:
        print(e)
        st.error("Failed to load the images for this pair")
        if st.button("Skip"):
            _advance_to_next_pair()
        return
    models = [model_a, model_b]
    idx = image_select(
        label="Select the image you prefer",
        images=images,
        index=-1,
        center=True,
        height=700,
        return_value="index",
    )
    if st.button("Skip"):
        _advance_to_next_pair()
    last_state = st.session_state.get("last_state")
    if last_state is not None:
        st.markdown(last_state)
    # index=-1 starts with nothing selected; only 0 or 1 is an actual vote.
    if idx in (0, 1):
        selection_result = {
            "model_a": model_a,
            "model_b": model_b,
            "winner": "model_a" if idx == 0 else "model_b",
            "time": time.time(),
        }
        st.session_state["selections"].append(selection_result)
        st.session_state["selection_count"] += 1
        st.session_state["last_state"] = f"[Selection #{st.session_state['selection_count']}] You selected Image `#{idx+1}` - Model: {models[idx]}"
        submit_selection(selection_result)
        _advance_to_next_pair()
def leaderboard():
    """Render the global leaderboard fetched from the backend.

    Presumably ``get_leaderboard_data()`` returns a dict whose ``"leaderboard"``
    entry is a list of row dicts containing at least the displayed columns.
    """
    data = get_leaderboard_data()
    if data is None:
        return
    st.markdown("## Global Leaderboard")
    st.markdown("""
This leaderboard shows the performance of different models based on user selections.
- **Elo Rating**: A relative rating system. Higher scores indicate better performance.
- **Win Rate**: The percentage of times a model was chosen when presented.
- **#Selections**: Total number of times this model was presented in a pair.
""")
    st.warning("This leaderboard is just for fun and **does not reflect the actual performance of the models.**")
    rows = data.get("leaderboard") if isinstance(data, dict) else None
    if not rows:
        # An empty list yields a DataFrame without these columns, so the column
        # selection below would raise KeyError.
        st.info("No selections have been recorded yet.")
        return
    df = pd.DataFrame(rows)[["Model", "Elo Rating", "Win Rate", "#Selections"]]
    st.dataframe(df, hide_index=True)
MIN_PERSONAL_SELECTIONS = 30
_REQUIRED_SELECTION_COLUMNS = ("model_a", "model_b", "winner")


def _load_uploaded_selections(uploaded_file):
    """Replace the session's selections with the rows of an uploaded CSV.

    Returns True when new selections were stored (the caller should rerun),
    False when this file was already loaded or could not be used.
    """
    upload_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("loaded_upload") == upload_key:
        # st.file_uploader hands back the same file on every rerun; without
        # this guard a file with too few rows would be reloaded and trigger
        # st.rerun() forever.
        return False
    st.session_state["loaded_upload"] = upload_key
    try:
        logs = pd.read_csv(uploaded_file)
    except ValueError as e:  # pandas parse/empty-data and decode errors
        st.error(f"Could not read the uploaded file: {e}")
        return False
    missing = [col for col in _REQUIRED_SELECTION_COLUMNS if col not in logs.columns]
    if missing:
        st.error(f"The uploaded file is missing required columns: {', '.join(missing)}")
        return False
    # "Unnamed: 0" is the index column written by the download button's to_csv().
    logs = logs.drop(columns=["Unnamed: 0"], errors="ignore")
    st.session_state["selections"] = logs.to_dict(orient="records")
    return True


def _personal_leaderboard_table(records):
    """Build the personal leaderboard DataFrame, highest Elo first."""
    elo_ratings = compute_elo(records)
    win_rates = compute_win_rates(records)
    selection_counts = compute_selection_counts(records)
    rows = [
        {
            "Model": model,
            "Elo Rating": round(elo_ratings[model], 2),
            "Win Rate": f"{win_rates[model]*100:.2f}%",
            "#Selections": count,
        }
        for model, count in selection_counts.items()
    ]
    df = pd.DataFrame(rows, columns=["Model", "Elo Rating", "Win Rate", "#Selections"])
    return df.sort_values("Elo Rating", ascending=False).reset_index(drop=True)


def my_leaderboard():
    """Render a leaderboard computed only from this session's selections.

    Below ``MIN_PERSONAL_SELECTIONS`` selections, offers to load a CSV
    previously produced by this page's download button instead.
    """
    records = st.session_state.get("selections", [])
    if len(records) < MIN_PERSONAL_SELECTIONS:
        st.markdown(f"Select at least {MIN_PERSONAL_SELECTIONS} images to see your personal leaderboard")
        uploaded_file = st.file_uploader("Or load your previous selections:", accept_multiple_files=False)
        if uploaded_file is not None and _load_uploaded_selections(uploaded_file):
            st.rerun()
        return
    selections = pd.DataFrame(records)
    st.markdown("## Personal Leaderboard")
    st.markdown("""
This leaderboard is based on your personal selections.
- **Elo Rating**: Calculated from your choices. Higher scores indicate models you prefer.
- **Win Rate**: The percentage of times you chose each model when it was presented.
- **#Selections**: Number of times you've seen this model in a pair.
""")
    # Round-trip through the DataFrame so every record has the same keys.
    table = _personal_leaderboard_table(selections.to_dict('records'))
    st.dataframe(table, hide_index=True)
    st.markdown("## Your Recent Selections")
    st.dataframe(selections.tail(20))
    st.download_button('Download your selection data as CSV', selections.to_csv().encode('utf-8'), "my_selections.csv", "text/csv")
def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
    """Replay battles in order with the online Elo update and return ratings.

    Args:
        battles: Iterable of dicts with ``model_a``, ``model_b`` and ``winner``
            keys. ``winner`` is ``"model_a"`` or ``"model_b"``; any other value
            is scored as a tie.
        K: Size of each rating update.
        SCALE: Rating difference that corresponds to ``BASE``-to-1 expected odds.
        BASE: Base of the logistic expectation curve.
        INIT_RATING: Rating a model starts with the first time it is seen.

    Returns:
        A ``defaultdict`` mapping model name to rating; models never seen read
        as ``INIT_RATING``.
    """
    ratings = defaultdict(lambda: INIT_RATING)
    for match in battles:
        left = match['model_a']
        right = match['model_b']
        outcome = match['winner']
        left_rating = ratings[left]
        right_rating = ratings[right]
        expected_left = 1 / (1 + BASE ** ((right_rating - left_rating) / SCALE))
        expected_right = 1 / (1 + BASE ** ((left_rating - right_rating) / SCALE))
        if outcome == "model_a":
            actual_left = 1
        elif outcome == "model_b":
            actual_left = 0
        else:
            actual_left = 0.5
        ratings[left] += K * (actual_left - expected_left)
        ratings[right] += K * (1 - actual_left - expected_right)
    return ratings
def compute_win_rates(battles):
    """Return, per model, the fraction of its appearances that were wins.

    A battle whose ``winner`` is neither ``"model_a"`` nor ``"model_b"`` counts
    as an appearance for both models and a win for neither.
    """
    wins = defaultdict(int)
    appearances = defaultdict(int)
    for match in battles:
        left, right = match['model_a'], match['model_b']
        outcome = match['winner']
        if outcome == "model_a":
            wins[left] += 1
        elif outcome == "model_b":
            wins[right] += 1
        appearances[left] += 1
        appearances[right] += 1
    rates = {}
    for model in set(appearances) | set(wins):
        total = appearances[model]
        rates[model] = wins[model] / total if total > 0 else 0
    return rates
def compute_selection_counts(battles):
    """Count how many battles each model took part in, on either side.

    Returns a ``defaultdict(int)``, so models never seen read as 0.
    """
    counts = defaultdict(int)
    for match in battles:
        for side in ('model_a', 'model_b'):
            counts[match[side]] += 1
    return counts
# Register the three page functions with Streamlit's multipage navigation and
# run whichever page is currently selected.
_page_functions = (arena, leaderboard, my_leaderboard)
pages = [st.Page(page_function) for page_function in _page_functions]
selected_page = st.navigation(pages)
selected_page.run()