import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import requests
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
class LeaderboardModel(BaseModel):
    model_name: str
    type: str
    model_link: Optional[str] = None
    scores: Dict[str, float]
    co2_cost: Optional[float] = None


class LeaderboardData(BaseModel):
    models: List[LeaderboardModel]
    updated_at: str
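
# For reference, a payload matching LeaderboardData looks roughly like this
# (illustrative values only; the score keys follow the "scores" dict built
# in generate_sample_data below):
#
# {
#   "models": [
#     {
#       "model_name": "meta-llama/llama-3-70b-instruct",
#       "type": "open",
#       "model_link": "https://huggingface.co/meta-llama/llama-3-70b-instruct",
#       "scores": {"average": 78.5, "ifeval": 80.1, "bbhi": 70.2, "math": 55.0,
#                  "gpqa": 40.3, "mujb": 45.6, "mmlu": 75.4},
#       "co2_cost": 12.3
#     }
#   ],
#   "updated_at": "2024-01-01T00:00:00"
# }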
app = FastAPI(
    title="LLM Leaderboard API",
    description="API for serving Open LLM Leaderboard data",
    version="1.0.0",
)
# Add CORS middleware to allow requests from your Gradio app
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify your exact frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
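
# A locked-down production setup might look like this (the origin below is a
# placeholder for your actual frontend URL):
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://your-frontend.example.com"],
#     allow_credentials=True,
#     allow_methods=["GET"],
#     allow_headers=["*"],
# )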
# Cache for leaderboard data
cached_data: Optional[Dict] = None
cache_file = Path("leaderboard_cache.json")
def fetch_external_leaderboard_data(refresh: bool = False) -> Optional[Dict]:
    """
    Fetch leaderboard data from external sources like Hugging Face.
    Uses the local cache if available and refresh is False.
    """
    global cached_data
    if not refresh and cached_data:
        return cached_data
    if not refresh and cache_file.exists():
        try:
            with open(cache_file) as f:
                cached_data = json.load(f)
            return cached_data
        except (OSError, json.JSONDecodeError):
            pass  # Fall back to fetching if cache read fails
    try:
        # Try different endpoints that might contain leaderboard data
        endpoints = [
            "https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard/raw/main/leaderboard_data.json",
            "https://huggingface.co/api/spaces/HuggingFaceH4/open_llm_leaderboard/api/get_results",
        ]
        for url in endpoints:
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                data = response.json()
                cached_data = data
                with open(cache_file, "w") as f:
                    json.dump(data, f)
                return data
        # If all endpoints fail, return None
        return None
    except Exception as e:
        print(f"Error fetching external leaderboard data: {e}")
        return None
def generate_sample_data() -> Dict:
    """
    Generate sample leaderboard data when external data can't be fetched.
    """
    models = [
        {"model_name": "meta-llama/llama-3-70b-instruct", "type": "open"},
        {"model_name": "mistralai/Mistral-7B-Instruct-v0.3", "type": "open"},
        {"model_name": "google/gemma-7b-it", "type": "open"},
        {"model_name": "Qwen/Qwen2-7B-Instruct", "type": "open"},
        {"model_name": "anthropic/claude-3-opus", "type": "closed", "external_link": "https://www.anthropic.com/claude"},
        {"model_name": "OpenAI/gpt-4o", "type": "closed", "external_link": "https://openai.com/gpt-4"},
        {"model_name": "01-ai/Yi-1.5-34B-Chat", "type": "open"},
        {"model_name": "google/gemma-2b", "type": "open"},
        {"model_name": "microsoft/phi-3-mini-4k-instruct", "type": "open"},
        {"model_name": "microsoft/phi-3-mini-128k-instruct", "type": "open"},
        {"model_name": "stabilityai/stable-beluga-7b", "type": "open"},
        {"model_name": "togethercomputer/RedPajama-INCITE-7B-Instruct", "type": "open"},
        {"model_name": "databricks/dbrx-instruct", "type": "closed", "external_link": "https://www.databricks.com/product/machine-learning/large-language-models"},
        {"model_name": "mosaicml/mpt-7b-instruct", "type": "open"},
        {"model_name": "01-ai/Yi-1.5-9B-Chat", "type": "open"},
        {"model_name": "anthropic/claude-3-sonnet", "type": "closed", "external_link": "https://www.anthropic.com/claude"},
        {"model_name": "cohere/command-r-plus", "type": "closed", "external_link": "https://cohere.com/models/command-r-plus"},
        {"model_name": "meta-llama/llama-3-8b-instruct", "type": "open"},
    ]
    np.random.seed(42)  # For reproducibility
    model_data = []
    for model_info in models:
        model_name = model_info["model_name"]
        model_type = model_info["type"]
        external_link = model_info.get("external_link")
        # Generate random scores
        average = round(np.random.uniform(40, 90), 2)
        ifeval = round(np.random.uniform(30, 90), 2)
        bbhi = round(np.random.uniform(40, 85), 2)
        math = round(np.random.uniform(20, 80), 2)
        gpqa = round(np.random.uniform(10, 70), 2)
        mujb = round(np.random.uniform(10, 70), 2)
        mmlu = round(np.random.uniform(40, 85), 2)
        co2_cost = round(np.random.uniform(1, 100), 2)
        # Closed models keep their external link; open models get a link to
        # their Hugging Face repo
        if external_link:
            model_link = external_link
        elif "/" in model_name:
            model_link = f"https://huggingface.co/{model_name}"
        else:
            model_link = f"https://huggingface.co/models?search={model_name}"
        model_data.append({
            "model_name": model_name,
            "type": model_type,
            "model_link": model_link,
            "scores": {
                "average": average,
                "ifeval": ifeval,
                "bbhi": bbhi,
                "math": math,
                "gpqa": gpqa,
                "mujb": mujb,
                "mmlu": mmlu,
            },
            "co2_cost": co2_cost,
        })
    # Sort by average score
    model_data.sort(key=lambda x: x["scores"]["average"], reverse=True)
    # Create the final data structure
    return {
        "models": model_data,
        "updated_at": datetime.now().isoformat(),
    }
@app.get("/")
def read_root():
    return {"message": "Welcome to the LLM Leaderboard API"}
@app.get("/leaderboard")
def get_leaderboard(refresh: bool = Query(False, description="Force refresh data from source")):
    """
    Get the full leaderboard data.
    If refresh is True, fetch from the source instead of using the cache.
    """
    external_data = fetch_external_leaderboard_data(refresh=refresh)
    if external_data:
        # Process external data to match our expected format
        try:
            # Here you would transform the external data to match the
            # LeaderboardData model; this simplified example returns it as-is
            return external_data
        except Exception as e:
            print(f"Error processing external data: {e}")
    # Fall back to sample data if external data can't be fetched or processed
    return generate_sample_data()
@app.get("/models")
def get_models():
    """Get a list of all model names in the leaderboard."""
    data = fetch_external_leaderboard_data() or generate_sample_data()
    return [model["model_name"] for model in data["models"]]
@app.get("/models/{model_name:path}")
def get_model_details(model_name: str):
    """Get detailed information about a specific model."""
    # The :path converter lets org/model names containing "/" match the route
    data = fetch_external_leaderboard_data() or generate_sample_data()
    for model in data["models"]:
        if model["model_name"] == model_name:
            return model
    raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
@app.get("/filter-counts")
def get_filter_counts():
    """
    Get counts for different filter categories to display in the UI.
    This matches what's shown in the 'Quick Filters' section of the leaderboard.
    """
    data = fetch_external_leaderboard_data() or generate_sample_data()
    # Count models by different categories
    edge_count = 0
    consumer_count = 0
    midrange_count = 0
    gpu_rich_count = 0
    official_count = 0
    for model in data["models"]:
        average = model["scores"].get("average", 0) if "scores" in model else None
        if average is not None:
            # Edge devices (typically small models)
            if average < 45:
                edge_count += 1
            # Consumer (moderate size/performance)
            elif average < 55:
                consumer_count += 1
            # Mid-range
            elif average < 65:
                midrange_count += 1
            # GPU-rich (high-end models)
            else:
                gpu_rich_count += 1
        # Official providers
        # This is just placeholder logic - adapt to your actual criteria
        name = model["model_name"]
        if "/" not in name or name.startswith(("meta-llama/", "google/")):
            official_count += 1
    return {
        "edge_devices": edge_count,
        "consumers": consumer_count,
        "midrange": midrange_count,
        "gpu_rich": gpu_rich_count,
        "official_providers": official_count,
    }
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
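
# Once the server is running, the endpoints can be exercised with curl
# (paths match the route decorators defined above; sample output comes from
# generate_sample_data when no external source is reachable):
#
#   curl http://localhost:8000/
#   curl "http://localhost:8000/leaderboard?refresh=true"
#   curl http://localhost:8000/models
#   curl http://localhost:8000/models/meta-llama/llama-3-70b-instruct
#   curl http://localhost:8000/filter-counts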