Arena / arena.py
FredOru's picture
(feat) db csv et admin panels
1bc7a0c
import trueskill as ts
import pandas as pd
import random
import datetime
from typing import Dict, List, Tuple, Union
import db
import pandas as pd
from typing import TypedDict
MU_init = ts.Rating().mu
SIGMA_init = ts.Rating().sigma
class Prompt(TypedDict):
id: int
name: str
text: str
class Arena:
"""
Une arène pour comparer et classer des prompts en utilisant l'algorithme TrueSkill.
"""
def init_estimates(self, reboot=True) -> None:
"""
Initialise les estimations des prompts avec des ratings TrueSkill par défaut.
reboot : si le fichier estimates.csv existe déjà, on le laisse tel quel.
"""
estimates = db.load("estimates")
if not estimates.empty and reboot:
return None
if estimates.empty:
for i in db.load("prompts")["id"].to_list():
db.insert(
"estimates",
{
"prompt_id": i,
"mu": MU_init,
"sigma": SIGMA_init,
},
)
def load(self, table_name: str) -> pd.DataFrame:
"""
fonction back pour l'UI.
Charge les données d'une table depuis le fichier CSV.
"""
return db.load(table_name)
def replace(self, table_name: str, df: pd.DataFrame) -> pd.DataFrame:
"""
fonction back pour l'UI.
Remplace le contenu d'une table par les données du fichier CSV.
Pour l'admin uniquement
"""
return db.replace(table_name, df)
def select_match(self) -> Tuple[Prompt, Prompt]:
"""
Sélectionne deux prompts pour un match en privilégiant ceux avec une grande incertitude.
Returns:
Un tuple contenant les IDs des deux prompts à comparer (prompt_a, prompt_b)
"""
# le prompt le plus incertain (sigma le plus élevé)
estimates = db.load("estimates")
estimate_a = estimates.sort_values(by="sigma", ascending=False).iloc[0]
# le prompt le plus proche en niveau (mu) du prompt_a
estimate_b = (
estimates.loc[estimates["prompt_id"] != estimate_a["prompt_id"]]
.assign(delta_mu=lambda df_: abs(df_["mu"] - estimate_a["mu"]))
.sort_values(by="delta_mu", ascending=True)
.iloc[0]
)
prompts = db.load("prompts")
prompt_a = prompts.query(f"id == {estimate_a['prompt_id']}").iloc[0].to_dict()
prompt_b = prompts.query(f"id == {estimate_b['prompt_id']}").iloc[0].to_dict()
# We need to update the selection strategy to prefer prompts with high uncertainty
# but also consider prompts that are close in ranking (within 5 positions)
# Create pairs of prompts that are at most 5 positions apart in the ranking
# close_pairs = []
# for i in range(len(prompt_ids)):
# for j in range(i + 1, min(i + 6, len(prompt_ids))):
# close_pairs.append((prompt_ids[i], prompt_ids[j]))
return prompt_a, prompt_b
def record_result(self, winner_id: str, loser_id: str) -> None:
# Obtenir les ratings actuels
estimates = db.load("estimates")
winner_estimate = (
estimates[estimates["prompt_id"] == winner_id].iloc[0].to_dict()
)
loser_estimate = estimates[estimates["prompt_id"] == loser_id].iloc[0].to_dict()
winner_rating = ts.Rating(winner_estimate["mu"], winner_estimate["sigma"])
loser_rating = ts.Rating(loser_estimate["mu"], loser_estimate["sigma"])
winner_new_rating, loser_new_rating = ts.rate_1vs1(winner_rating, loser_rating)
db.update(
"estimates",
winner_estimate["id"],
{"mu": winner_new_rating.mu, "sigma": winner_new_rating.sigma},
)
db.update(
"estimates",
loser_estimate["id"],
{"mu": loser_new_rating.mu, "sigma": loser_new_rating.sigma},
)
db.insert(
"votes",
{
"winner_id": winner_id,
"loser_id": loser_id,
# "timestamp": datetime.datetime.now().isoformat(),
},
)
return None
def get_rankings(self) -> pd.DataFrame:
"""
Obtient le classement actuel des prompts.
Returns:
Liste de dictionnaires contenant le classement de chaque prompt avec
ses informations (rang, id, texte, mu, sigma, score)
"""
prompts = db.load("prompts")
estimates = db.load("estimates").drop(columns=["id"])
rankings = prompts.merge(estimates, left_on="id", right_on="prompt_id").drop(
columns=["id", "prompt_id"]
)
return rankings.sort_values(by="mu", ascending=False)
# eventuellement afficher plutôt mu - 3 sigma pour être conservateur
def get_progress(self) -> str:
"""
Renvoie des statistiques sur la progression du tournoi.
Returns:
Dictionnaire contenant des informations sur la progression:
- total_prompts: nombre total de prompts
- total_matches: nombre total de matchs joués
- avg_sigma: incertitude moyenne des ratings
- progress: pourcentage estimé de progression du tournoi
- estimated_remaining_matches: estimation du nombre de matchs restants
"""
prompts = db.load("prompts")
estimates = db.load("estimates")
votes = db.load("votes")
avg_sigma = estimates["sigma"].mean()
# Estimer quel pourcentage du tournoi est complété
# En se basant sur la réduction moyenne de sigma par rapport à la valeur initiale
initial_sigma = ts.Rating().sigma
progress = min(100, max(0, (1 - avg_sigma / initial_sigma) * 100))
msg = f"""{len(prompts)} propositions à départager
{len(votes)} matchs joués
{avg_sigma:.2f} d'incertitude moyenne"""
return msg