Arena / arena.py
FredOru's picture
first working draft
330f0b8
raw
history blame contribute delete
7.65 kB
import trueskill
import random
import json
import os
import datetime
from typing import Dict, List, Tuple, Union
class PromptArena:
"""
Une arène pour comparer et classer des prompts en utilisant l'algorithme TrueSkill.
Cette classe permet d'organiser des "matchs" entre prompts où des utilisateurs
choisissent leur préféré, mettant à jour les classements TrueSkill en fonction
des résultats.
"""
def __init__(
self, prompts: List[str], results_file: str = "ratings_archive.json"
) -> None:
"""
Initialise une arène de prompts.
Args:
prompts: Liste de textes de prompts
results_file: Chemin du fichier pour sauvegarder/charger les ratings
"""
self.prompts: Dict[str, str] = {
str(idx + 1): prompt for idx, prompt in enumerate(prompts)
}
self.ratings: Dict[str, trueskill.Rating] = {} # {prompt_id: trueskill.Rating}
self.results_file: str = results_file
self.match_history: List[Dict[str, str]] = []
# Charger les ratings si le fichier existe
self._load_ratings()
# Initialiser les ratings pour les nouveaux prompts
for prompt_id in self.prompts:
if prompt_id not in self.ratings:
self.ratings[prompt_id] = trueskill.Rating()
def _load_ratings(self) -> None:
"""Charge les ratings depuis un fichier JSON si disponible"""
if os.path.exists(self.results_file):
with open(self.results_file, "r", encoding="utf-8") as f:
data = json.load(f)
# Convertir les données stockées en objets trueskill.Rating
for prompt_id, rating_data in data["ratings"].items():
self.ratings[prompt_id] = trueskill.Rating(
mu=rating_data["mu"], sigma=rating_data["sigma"]
)
self.match_history = data.get("match_history", [])
def _save_ratings(self) -> None:
"""Sauvegarde les ratings et l'historique dans un fichier JSON"""
data = {
"ratings": {
prompt_id: {"mu": rating.mu, "sigma": rating.sigma}
for prompt_id, rating in self.ratings.items()
},
"match_history": self.match_history,
}
with open(self.results_file, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def add_prompt(self, prompt_id: str, prompt_text: str) -> None:
"""
Ajoute un nouveau prompt à l'arène.
Args:
prompt_id: Identifiant unique du prompt
prompt_text: Texte du prompt
"""
self.prompts[prompt_id] = prompt_text
if prompt_id not in self.ratings:
self.ratings[prompt_id] = trueskill.Rating()
self._save_ratings()
def select_match(self) -> Tuple[str, str]:
"""
Sélectionne deux prompts pour un match en privilégiant ceux avec une grande incertitude.
La stratégie est de sélectionner d'abord le prompt avec la plus grande incertitude (sigma),
puis de trouver un adversaire avec un niveau (mu) similaire.
Returns:
Un tuple contenant les IDs des deux prompts à comparer (prompt_a, prompt_b)
"""
# Stratégie: choisir des prompts avec sigma élevé et des niveaux similaires
prompt_ids = list(self.prompts.keys())
# Trier par incertitude (sigma) décroissante
prompt_ids.sort(key=lambda pid: self.ratings[pid].sigma, reverse=True)
# Sélectionner le premier prompt (plus grande incertitude)
prompt_a = prompt_ids[0]
# Pour le second, trouver un prompt proche en niveau (mu)
mu_a = self.ratings[prompt_a].mu
# Trier les prompts restants par proximité de mu
remaining_prompts = [p for p in prompt_ids if p != prompt_a]
remaining_prompts.sort(key=lambda pid: abs(self.ratings[pid].mu - mu_a))
# Prendre un prompt parmi les 3 plus proches (avec un peu de randomisation)
top_n = min(3, len(remaining_prompts))
prompt_b = random.choice(remaining_prompts[:top_n])
return prompt_a, prompt_b
def record_result(self, winner_id: str, loser_id: str) -> None:
"""
Enregistre le résultat d'un match et met à jour les ratings.
Args:
winner_id: ID du prompt gagnant
loser_id: ID du prompt perdant
"""
# Obtenir les ratings actuels
winner_rating = self.ratings[winner_id]
loser_rating = self.ratings[loser_id]
# Mettre à jour les ratings (TrueSkill s'occupe des calculs)
self.ratings[winner_id], self.ratings[loser_id] = trueskill.rate_1vs1(
winner_rating, loser_rating
)
# Enregistrer le match dans l'historique
self.match_history.append(
{
"winner": winner_id,
"loser": loser_id,
"timestamp": str(datetime.datetime.now()),
}
)
# Sauvegarder les résultats
self._save_ratings()
def get_rankings(self) -> List[Dict[str, Union[int, str, float]]]:
"""
Obtient le classement actuel des prompts.
Returns:
Liste de dictionnaires contenant le classement de chaque prompt avec
ses informations (rang, id, texte, mu, sigma, score)
"""
# Trier les prompts par "conserved expected score" = mu - 3*sigma
# (une façon conservatrice d'estimer la compétence en tenant compte de l'incertitude)
sorted_prompts = sorted(
self.ratings.items(), key=lambda x: x[1].mu - 3 * x[1].sigma, reverse=True
)
rankings = []
for i, (prompt_id, rating) in enumerate(sorted_prompts, 1):
prompt_text = self.prompts.get(prompt_id, "Prompt inconnu")
rankings.append(
{
"rank": i,
"prompt_id": prompt_id,
"prompt": prompt_text,
"mu": rating.mu,
"sigma": rating.sigma,
"score": rating.mu - 3 * rating.sigma, # Score conservateur
}
)
return rankings
def get_progress(self) -> Dict[str, Union[int, float]]:
"""
Renvoie des statistiques sur la progression du tournoi.
Returns:
Dictionnaire contenant des informations sur la progression:
- total_prompts: nombre total de prompts
- total_matches: nombre total de matchs joués
- avg_sigma: incertitude moyenne des ratings
- progress: pourcentage estimé de progression du tournoi
- estimated_remaining_matches: estimation du nombre de matchs restants
"""
total_prompts = len(self.prompts)
total_matches = len(self.match_history)
avg_sigma = sum(r.sigma for r in self.ratings.values()) / max(
1, len(self.ratings)
)
# Estimer quel pourcentage du tournoi est complété
# En se basant sur la réduction moyenne de sigma par rapport à la valeur initiale
initial_sigma = trueskill.Rating().sigma
progress = min(100, max(0, (1 - avg_sigma / initial_sigma) * 100))
return {
"total_prompts": total_prompts,
"total_matches": total_matches,
"avg_sigma": avg_sigma,
"progress": progress,
"estimated_remaining_matches": int(total_prompts * 15) - total_matches,
}