File size: 7,647 Bytes
330f0b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
import trueskill
import random
import json
import os
import datetime
from typing import Dict, List, Tuple, Union
class PromptArena:
"""
Une arène pour comparer et classer des prompts en utilisant l'algorithme TrueSkill.
Cette classe permet d'organiser des "matchs" entre prompts où des utilisateurs
choisissent leur préféré, mettant à jour les classements TrueSkill en fonction
des résultats.
"""
def __init__(
self, prompts: List[str], results_file: str = "ratings_archive.json"
) -> None:
"""
Initialise une arène de prompts.
Args:
prompts: Liste de textes de prompts
results_file: Chemin du fichier pour sauvegarder/charger les ratings
"""
self.prompts: Dict[str, str] = {
str(idx + 1): prompt for idx, prompt in enumerate(prompts)
}
self.ratings: Dict[str, trueskill.Rating] = {} # {prompt_id: trueskill.Rating}
self.results_file: str = results_file
self.match_history: List[Dict[str, str]] = []
# Charger les ratings si le fichier existe
self._load_ratings()
# Initialiser les ratings pour les nouveaux prompts
for prompt_id in self.prompts:
if prompt_id not in self.ratings:
self.ratings[prompt_id] = trueskill.Rating()
def _load_ratings(self) -> None:
"""Charge les ratings depuis un fichier JSON si disponible"""
if os.path.exists(self.results_file):
with open(self.results_file, "r", encoding="utf-8") as f:
data = json.load(f)
# Convertir les données stockées en objets trueskill.Rating
for prompt_id, rating_data in data["ratings"].items():
self.ratings[prompt_id] = trueskill.Rating(
mu=rating_data["mu"], sigma=rating_data["sigma"]
)
self.match_history = data.get("match_history", [])
def _save_ratings(self) -> None:
"""Sauvegarde les ratings et l'historique dans un fichier JSON"""
data = {
"ratings": {
prompt_id: {"mu": rating.mu, "sigma": rating.sigma}
for prompt_id, rating in self.ratings.items()
},
"match_history": self.match_history,
}
with open(self.results_file, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def add_prompt(self, prompt_id: str, prompt_text: str) -> None:
"""
Ajoute un nouveau prompt à l'arène.
Args:
prompt_id: Identifiant unique du prompt
prompt_text: Texte du prompt
"""
self.prompts[prompt_id] = prompt_text
if prompt_id not in self.ratings:
self.ratings[prompt_id] = trueskill.Rating()
self._save_ratings()
def select_match(self) -> Tuple[str, str]:
"""
Sélectionne deux prompts pour un match en privilégiant ceux avec une grande incertitude.
La stratégie est de sélectionner d'abord le prompt avec la plus grande incertitude (sigma),
puis de trouver un adversaire avec un niveau (mu) similaire.
Returns:
Un tuple contenant les IDs des deux prompts à comparer (prompt_a, prompt_b)
"""
# Stratégie: choisir des prompts avec sigma élevé et des niveaux similaires
prompt_ids = list(self.prompts.keys())
# Trier par incertitude (sigma) décroissante
prompt_ids.sort(key=lambda pid: self.ratings[pid].sigma, reverse=True)
# Sélectionner le premier prompt (plus grande incertitude)
prompt_a = prompt_ids[0]
# Pour le second, trouver un prompt proche en niveau (mu)
mu_a = self.ratings[prompt_a].mu
# Trier les prompts restants par proximité de mu
remaining_prompts = [p for p in prompt_ids if p != prompt_a]
remaining_prompts.sort(key=lambda pid: abs(self.ratings[pid].mu - mu_a))
# Prendre un prompt parmi les 3 plus proches (avec un peu de randomisation)
top_n = min(3, len(remaining_prompts))
prompt_b = random.choice(remaining_prompts[:top_n])
return prompt_a, prompt_b
def record_result(self, winner_id: str, loser_id: str) -> None:
"""
Enregistre le résultat d'un match et met à jour les ratings.
Args:
winner_id: ID du prompt gagnant
loser_id: ID du prompt perdant
"""
# Obtenir les ratings actuels
winner_rating = self.ratings[winner_id]
loser_rating = self.ratings[loser_id]
# Mettre à jour les ratings (TrueSkill s'occupe des calculs)
self.ratings[winner_id], self.ratings[loser_id] = trueskill.rate_1vs1(
winner_rating, loser_rating
)
# Enregistrer le match dans l'historique
self.match_history.append(
{
"winner": winner_id,
"loser": loser_id,
"timestamp": str(datetime.datetime.now()),
}
)
# Sauvegarder les résultats
self._save_ratings()
def get_rankings(self) -> List[Dict[str, Union[int, str, float]]]:
"""
Obtient le classement actuel des prompts.
Returns:
Liste de dictionnaires contenant le classement de chaque prompt avec
ses informations (rang, id, texte, mu, sigma, score)
"""
# Trier les prompts par "conserved expected score" = mu - 3*sigma
# (une façon conservatrice d'estimer la compétence en tenant compte de l'incertitude)
sorted_prompts = sorted(
self.ratings.items(), key=lambda x: x[1].mu - 3 * x[1].sigma, reverse=True
)
rankings = []
for i, (prompt_id, rating) in enumerate(sorted_prompts, 1):
prompt_text = self.prompts.get(prompt_id, "Prompt inconnu")
rankings.append(
{
"rank": i,
"prompt_id": prompt_id,
"prompt": prompt_text,
"mu": rating.mu,
"sigma": rating.sigma,
"score": rating.mu - 3 * rating.sigma, # Score conservateur
}
)
return rankings
def get_progress(self) -> Dict[str, Union[int, float]]:
"""
Renvoie des statistiques sur la progression du tournoi.
Returns:
Dictionnaire contenant des informations sur la progression:
- total_prompts: nombre total de prompts
- total_matches: nombre total de matchs joués
- avg_sigma: incertitude moyenne des ratings
- progress: pourcentage estimé de progression du tournoi
- estimated_remaining_matches: estimation du nombre de matchs restants
"""
total_prompts = len(self.prompts)
total_matches = len(self.match_history)
avg_sigma = sum(r.sigma for r in self.ratings.values()) / max(
1, len(self.ratings)
)
# Estimer quel pourcentage du tournoi est complété
# En se basant sur la réduction moyenne de sigma par rapport à la valeur initiale
initial_sigma = trueskill.Rating().sigma
progress = min(100, max(0, (1 - avg_sigma / initial_sigma) * 100))
return {
"total_prompts": total_prompts,
"total_matches": total_matches,
"avg_sigma": avg_sigma,
"progress": progress,
"estimated_remaining_matches": int(total_prompts * 15) - total_matches,
}
|