FredOru commited on
Commit
330f0b8
·
1 Parent(s): b609a55

first working draft

Browse files
Files changed (5) hide show
  1. app.py +234 -0
  2. arena.py +204 -0
  3. poetry.lock +0 -0
  4. prompts.csv +5 -0
  5. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import os
4
+ import time
5
+ from threading import Thread
6
+ from arena import PromptArena
7
+
8
+ LABEL_A = "Proposition A"
9
+ LABEL_B = "Proposition B"
10
+
11
+
12
+ class PromptArenaApp:
13
+ """
14
+ Classe pour encapsuler l'arène et gérer l'interface Gradio.
15
+ """
16
+
17
+ def __init__(self, arena: PromptArena) -> None:
18
+ """
19
+ Initialise l'application et charge les prompts depuis le fichier CSV.
20
+ """
21
+ self.arena: PromptArena = arena
22
+
23
+ def select_and_display_match(self):
24
+ """
25
+ Sélectionne un match et l'affiche.
26
+
27
+ Returns:
28
+ Tuple contenant:
29
+ - Le texte du premier prompt
30
+ - Le texte du second prompt
31
+ - Un dictionnaire d'état contenant les IDs des prompts
32
+ """
33
+
34
+ try:
35
+ prompt_a_id, prompt_b_id = self.arena.select_match()
36
+ prompt_a_text = self.arena.prompts.get(prompt_a_id, "")
37
+ prompt_b_text = self.arena.prompts.get(prompt_b_id, "")
38
+
39
+ state = {"prompt_a_id": prompt_a_id, "prompt_b_id": prompt_b_id}
40
+
41
+ return (
42
+ prompt_a_text,
43
+ prompt_b_text,
44
+ state,
45
+ gr.update(interactive=True), # button A
46
+ gr.update(interactive=True), # button B
47
+ gr.update(interactive=False), # match button
48
+ )
49
+ except Exception as e:
50
+ return f"Erreur lors de la sélection d'un match: {str(e)}", "", "", {}
51
+
52
+ def record_winner_a(self, state: dict[str, str]):
53
+ try:
54
+ prompt_a_id = state["prompt_a_id"]
55
+ prompt_b_id = state["prompt_b_id"]
56
+
57
+ self.arena.record_result(
58
+ prompt_a_id, prompt_b_id
59
+ ) # Mettre à jour la progression et le classement
60
+ progress_info = self.get_progress_info()
61
+ rankings_table = self.get_rankings_table()
62
+
63
+ return (
64
+ f"Vous avez choisi : {LABEL_A}",
65
+ progress_info,
66
+ rankings_table,
67
+ gr.update(interactive=False), # button A
68
+ gr.update(interactive=False), # button B
69
+ gr.update(interactive=True), # match button
70
+ )
71
+ except Exception as e:
72
+ return (
73
+ f"Erreur lors de l'enregistrement du résultat: {str(e)}",
74
+ "",
75
+ pd.DataFrame(),
76
+ )
77
+
78
+ def record_winner_b(self, state: dict[str, str]):
79
+ try:
80
+ prompt_a_id = state["prompt_a_id"]
81
+ prompt_b_id = state["prompt_b_id"]
82
+
83
+ self.arena.record_result(
84
+ prompt_b_id, prompt_a_id
85
+ ) # Mettre à jour la progression et le classement
86
+ progress_info = self.get_progress_info()
87
+ rankings_table = self.get_rankings_table()
88
+
89
+ return (
90
+ f"Vous avez choisi : {LABEL_B}",
91
+ progress_info,
92
+ rankings_table,
93
+ gr.update(interactive=False), # button A
94
+ gr.update(interactive=False), # button B
95
+ gr.update(interactive=True), # match button
96
+ )
97
+ except Exception as e:
98
+ return (
99
+ f"Erreur lors de l'enregistrement du résultat: {str(e)}",
100
+ "",
101
+ pd.DataFrame(),
102
+ )
103
+
104
+ def get_progress_info(self) -> str:
105
+ """
106
+ Obtient les informations sur la progression du tournoi.
107
+
108
+ Returns:
109
+ str: Message formaté contenant les statistiques de progression
110
+ """
111
+ if not self.arena:
112
+ return "Aucune arène initialisée. Veuillez d'abord charger des prompts."
113
+
114
+ try:
115
+ progress = self.arena.get_progress()
116
+
117
+ info = f"Prompts: {progress['total_prompts']}\n"
118
+ info += f"Matchs joués: {progress['total_matches']}\n"
119
+ info += f"Progression: {progress['progress']:.2f}%\n"
120
+ info += (
121
+ f"Matchs restants estimés: {progress['estimated_remaining_matches']}\n"
122
+ )
123
+ info += f"Incertitude moyenne (σ): {progress['avg_sigma']:.4f}"
124
+
125
+ return info
126
+ except Exception as e:
127
+ return f"Erreur lors de la récupération de la progression: {str(e)}"
128
+
129
+ def get_rankings_table(self) -> pd.DataFrame:
130
+ """
131
+ Obtient le classement des prompts sous forme de tableau.
132
+
133
+ Returns:
134
+ pd.DataFrame: Tableau de classement des prompts
135
+ """
136
+ if not self.arena:
137
+ return pd.DataFrame([{"Erreur": "Aucune arène initialisée"}])
138
+
139
+ try:
140
+ rankings = self.arena.get_rankings()
141
+
142
+ df = pd.DataFrame(rankings)
143
+ df = df[["rank", "prompt_id", "score"]]
144
+ df = df.rename(
145
+ columns={
146
+ "rank": "Rang",
147
+ "prompt_id": "ID",
148
+ "score": "Score",
149
+ }
150
+ )
151
+
152
+ return df
153
+ except Exception as e:
154
+ return pd.DataFrame([{"Erreur": str(e)}])
155
+
156
+ def create_ui(self) -> gr.Blocks:
157
+ """
158
+ Crée l'interface utilisateur Gradio.
159
+
160
+ Returns:
161
+ gr.Blocks: L'application Gradio configurée
162
+ """
163
+
164
+ with gr.Blocks(title="Prompt Arena", theme=gr.themes.Ocean()) as app:
165
+ gr.Markdown('<h1 style="text-align:center;">🥊 Prompt Arena 🥊</h1>')
166
+
167
+ with gr.Row():
168
+ select_btn = gr.Button("Lancer un nouveau match", variant="primary")
169
+
170
+ with gr.Row():
171
+ proposition_a = gr.Textbox(label=LABEL_A, interactive=False)
172
+ proposition_b = gr.Textbox(label=LABEL_B, interactive=False)
173
+
174
+ with gr.Row():
175
+ vote_a_btn = gr.Button("Choisir " + LABEL_A, interactive=False)
176
+ vote_b_btn = gr.Button("Choisir " + LABEL_B, interactive=False)
177
+
178
+ result = gr.Textbox("Résultat", interactive=False)
179
+ progress_info = gr.Textbox(
180
+ label="Progression du concours", interactive=False
181
+ )
182
+ rankings_table = gr.DataFrame(label="Classement des prompts")
183
+ state = gr.State() # contient les IDs des prompts du match en cours
184
+
185
+ select_btn.click(
186
+ self.select_and_display_match,
187
+ inputs=[],
188
+ outputs=[
189
+ proposition_a,
190
+ proposition_b,
191
+ state,
192
+ vote_a_btn,
193
+ vote_b_btn,
194
+ select_btn,
195
+ ],
196
+ )
197
+ vote_a_btn.click(
198
+ self.record_winner_a,
199
+ inputs=[state],
200
+ outputs=[
201
+ result,
202
+ progress_info,
203
+ rankings_table,
204
+ vote_a_btn,
205
+ vote_b_btn,
206
+ select_btn,
207
+ ],
208
+ )
209
+ vote_b_btn.click(
210
+ self.record_winner_b,
211
+ inputs=[state],
212
+ outputs=[
213
+ result,
214
+ progress_info,
215
+ rankings_table,
216
+ vote_a_btn,
217
+ vote_b_btn,
218
+ select_btn,
219
+ ],
220
+ )
221
+
222
+ gr.Row([progress_info, rankings_table])
223
+
224
+ return app
225
+
226
+
227
+ # Exemple d'utilisation
228
+ if __name__ == "__main__":
229
+ # load the prompts from the CSV file
230
+ prompts = pd.read_csv("prompts.csv", header=None).iloc[:, 0].tolist()
231
+ arena = PromptArena(prompts=prompts)
232
+ app_instance = PromptArenaApp(arena=arena)
233
+ app = app_instance.create_ui()
234
+ app.launch()
arena.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import trueskill
2
+ import random
3
+ import json
4
+ import os
5
+ import datetime
6
+ from typing import Dict, List, Tuple, Union
7
+
8
+
9
+ class PromptArena:
10
+ """
11
+ Une arène pour comparer et classer des prompts en utilisant l'algorithme TrueSkill.
12
+
13
+ Cette classe permet d'organiser des "matchs" entre prompts où des utilisateurs
14
+ choisissent leur préféré, mettant à jour les classements TrueSkill en fonction
15
+ des résultats.
16
+ """
17
+
18
+ def __init__(
19
+ self, prompts: List[str], results_file: str = "ratings_archive.json"
20
+ ) -> None:
21
+ """
22
+ Initialise une arène de prompts.
23
+
24
+ Args:
25
+ prompts: Liste de textes de prompts
26
+ results_file: Chemin du fichier pour sauvegarder/charger les ratings
27
+ """
28
+ self.prompts: Dict[str, str] = {
29
+ str(idx + 1): prompt for idx, prompt in enumerate(prompts)
30
+ }
31
+ self.ratings: Dict[str, trueskill.Rating] = {} # {prompt_id: trueskill.Rating}
32
+ self.results_file: str = results_file
33
+ self.match_history: List[Dict[str, str]] = []
34
+
35
+ # Charger les ratings si le fichier existe
36
+ self._load_ratings()
37
+
38
+ # Initialiser les ratings pour les nouveaux prompts
39
+ for prompt_id in self.prompts:
40
+ if prompt_id not in self.ratings:
41
+ self.ratings[prompt_id] = trueskill.Rating()
42
+
43
+ def _load_ratings(self) -> None:
44
+ """Charge les ratings depuis un fichier JSON si disponible"""
45
+ if os.path.exists(self.results_file):
46
+ with open(self.results_file, "r", encoding="utf-8") as f:
47
+ data = json.load(f)
48
+
49
+ # Convertir les données stockées en objets trueskill.Rating
50
+ for prompt_id, rating_data in data["ratings"].items():
51
+ self.ratings[prompt_id] = trueskill.Rating(
52
+ mu=rating_data["mu"], sigma=rating_data["sigma"]
53
+ )
54
+
55
+ self.match_history = data.get("match_history", [])
56
+
57
+ def _save_ratings(self) -> None:
58
+ """Sauvegarde les ratings et l'historique dans un fichier JSON"""
59
+ data = {
60
+ "ratings": {
61
+ prompt_id: {"mu": rating.mu, "sigma": rating.sigma}
62
+ for prompt_id, rating in self.ratings.items()
63
+ },
64
+ "match_history": self.match_history,
65
+ }
66
+
67
+ with open(self.results_file, "w", encoding="utf-8") as f:
68
+ json.dump(data, f, ensure_ascii=False, indent=2)
69
+
70
+ def add_prompt(self, prompt_id: str, prompt_text: str) -> None:
71
+ """
72
+ Ajoute un nouveau prompt à l'arène.
73
+
74
+ Args:
75
+ prompt_id: Identifiant unique du prompt
76
+ prompt_text: Texte du prompt
77
+ """
78
+ self.prompts[prompt_id] = prompt_text
79
+ if prompt_id not in self.ratings:
80
+ self.ratings[prompt_id] = trueskill.Rating()
81
+ self._save_ratings()
82
+
83
+ def select_match(self) -> Tuple[str, str]:
84
+ """
85
+ Sélectionne deux prompts pour un match en privilégiant ceux avec une grande incertitude.
86
+
87
+ La stratégie est de sélectionner d'abord le prompt avec la plus grande incertitude (sigma),
88
+ puis de trouver un adversaire avec un niveau (mu) similaire.
89
+
90
+ Returns:
91
+ Un tuple contenant les IDs des deux prompts à comparer (prompt_a, prompt_b)
92
+ """
93
+ # Stratégie: choisir des prompts avec sigma élevé et des niveaux similaires
94
+ prompt_ids = list(self.prompts.keys())
95
+
96
+ # Trier par incertitude (sigma) décroissante
97
+ prompt_ids.sort(key=lambda pid: self.ratings[pid].sigma, reverse=True)
98
+
99
+ # Sélectionner le premier prompt (plus grande incertitude)
100
+ prompt_a = prompt_ids[0]
101
+
102
+ # Pour le second, trouver un prompt proche en niveau (mu)
103
+ mu_a = self.ratings[prompt_a].mu
104
+
105
+ # Trier les prompts restants par proximité de mu
106
+ remaining_prompts = [p for p in prompt_ids if p != prompt_a]
107
+ remaining_prompts.sort(key=lambda pid: abs(self.ratings[pid].mu - mu_a))
108
+
109
+ # Prendre un prompt parmi les 3 plus proches (avec un peu de randomisation)
110
+ top_n = min(3, len(remaining_prompts))
111
+ prompt_b = random.choice(remaining_prompts[:top_n])
112
+
113
+ return prompt_a, prompt_b
114
+
115
+ def record_result(self, winner_id: str, loser_id: str) -> None:
116
+ """
117
+ Enregistre le résultat d'un match et met à jour les ratings.
118
+
119
+ Args:
120
+ winner_id: ID du prompt gagnant
121
+ loser_id: ID du prompt perdant
122
+ """
123
+ # Obtenir les ratings actuels
124
+ winner_rating = self.ratings[winner_id]
125
+ loser_rating = self.ratings[loser_id]
126
+
127
+ # Mettre à jour les ratings (TrueSkill s'occupe des calculs)
128
+ self.ratings[winner_id], self.ratings[loser_id] = trueskill.rate_1vs1(
129
+ winner_rating, loser_rating
130
+ )
131
+
132
+ # Enregistrer le match dans l'historique
133
+ self.match_history.append(
134
+ {
135
+ "winner": winner_id,
136
+ "loser": loser_id,
137
+ "timestamp": str(datetime.datetime.now()),
138
+ }
139
+ )
140
+
141
+ # Sauvegarder les résultats
142
+ self._save_ratings()
143
+
144
+ def get_rankings(self) -> List[Dict[str, Union[int, str, float]]]:
145
+ """
146
+ Obtient le classement actuel des prompts.
147
+
148
+ Returns:
149
+ Liste de dictionnaires contenant le classement de chaque prompt avec
150
+ ses informations (rang, id, texte, mu, sigma, score)
151
+ """
152
+ # Trier les prompts par "conserved expected score" = mu - 3*sigma
153
+ # (une façon conservatrice d'estimer la compétence en tenant compte de l'incertitude)
154
+ sorted_prompts = sorted(
155
+ self.ratings.items(), key=lambda x: x[1].mu - 3 * x[1].sigma, reverse=True
156
+ )
157
+
158
+ rankings = []
159
+ for i, (prompt_id, rating) in enumerate(sorted_prompts, 1):
160
+ prompt_text = self.prompts.get(prompt_id, "Prompt inconnu")
161
+ rankings.append(
162
+ {
163
+ "rank": i,
164
+ "prompt_id": prompt_id,
165
+ "prompt": prompt_text,
166
+ "mu": rating.mu,
167
+ "sigma": rating.sigma,
168
+ "score": rating.mu - 3 * rating.sigma, # Score conservateur
169
+ }
170
+ )
171
+
172
+ return rankings
173
+
174
+ def get_progress(self) -> Dict[str, Union[int, float]]:
175
+ """
176
+ Renvoie des statistiques sur la progression du tournoi.
177
+
178
+ Returns:
179
+ Dictionnaire contenant des informations sur la progression:
180
+ - total_prompts: nombre total de prompts
181
+ - total_matches: nombre total de matchs joués
182
+ - avg_sigma: incertitude moyenne des ratings
183
+ - progress: pourcentage estimé de progression du tournoi
184
+ - estimated_remaining_matches: estimation du nombre de matchs restants
185
+ """
186
+ total_prompts = len(self.prompts)
187
+ total_matches = len(self.match_history)
188
+
189
+ avg_sigma = sum(r.sigma for r in self.ratings.values()) / max(
190
+ 1, len(self.ratings)
191
+ )
192
+
193
+ # Estimer quel pourcentage du tournoi est complété
194
+ # En se basant sur la réduction moyenne de sigma par rapport à la valeur initiale
195
+ initial_sigma = trueskill.Rating().sigma
196
+ progress = min(100, max(0, (1 - avg_sigma / initial_sigma) * 100))
197
+
198
+ return {
199
+ "total_prompts": total_prompts,
200
+ "total_matches": total_matches,
201
+ "avg_sigma": avg_sigma,
202
+ "progress": progress,
203
+ "estimated_remaining_matches": int(total_prompts * 15) - total_matches,
204
+ }
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
prompts.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Prompt A 100
2
+ Prompt B 90
3
+ Prompt C 60
4
+ Prompt D 55
5
+ Prompt E 30
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ trueskill>=0.4.5,<0.5.0
2
+ gradio>=4.0.0
3
+ pandas>=2.0.0