File size: 7,688 Bytes
330f0b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import gradio as gr
import pandas as pd
import os
import time
from threading import Thread
from arena import PromptArena

LABEL_A = "Proposition A"
LABEL_B = "Proposition B"


class PromptArenaApp:
    """
    Classe pour encapsuler l'arène et gérer l'interface Gradio.
    """

    def __init__(self, arena: PromptArena) -> None:
        """
        Initialise l'application et charge les prompts depuis le fichier CSV.
        """
        self.arena: PromptArena = arena

    def select_and_display_match(self):
        """
        Sélectionne un match et l'affiche.

        Returns:
            Tuple contenant:
            - Le texte du premier prompt
            - Le texte du second prompt
            - Un dictionnaire d'état contenant les IDs des prompts
        """

        try:
            prompt_a_id, prompt_b_id = self.arena.select_match()
            prompt_a_text = self.arena.prompts.get(prompt_a_id, "")
            prompt_b_text = self.arena.prompts.get(prompt_b_id, "")

            state = {"prompt_a_id": prompt_a_id, "prompt_b_id": prompt_b_id}

            return (
                prompt_a_text,
                prompt_b_text,
                state,
                gr.update(interactive=True),  # button A
                gr.update(interactive=True),  # button B
                gr.update(interactive=False),  # match button
            )
        except Exception as e:
            return f"Erreur lors de la sélection d'un match: {str(e)}", "", "", {}

    def record_winner_a(self, state: dict[str, str]):
        try:
            prompt_a_id = state["prompt_a_id"]
            prompt_b_id = state["prompt_b_id"]

            self.arena.record_result(
                prompt_a_id, prompt_b_id
            )  # Mettre à jour la progression et le classement
            progress_info = self.get_progress_info()
            rankings_table = self.get_rankings_table()

            return (
                f"Vous avez choisi : {LABEL_A}",
                progress_info,
                rankings_table,
                gr.update(interactive=False),  # button A
                gr.update(interactive=False),  # button B
                gr.update(interactive=True),  # match button
            )
        except Exception as e:
            return (
                f"Erreur lors de l'enregistrement du résultat: {str(e)}",
                "",
                pd.DataFrame(),
            )

    def record_winner_b(self, state: dict[str, str]):
        try:
            prompt_a_id = state["prompt_a_id"]
            prompt_b_id = state["prompt_b_id"]

            self.arena.record_result(
                prompt_b_id, prompt_a_id
            )  # Mettre à jour la progression et le classement
            progress_info = self.get_progress_info()
            rankings_table = self.get_rankings_table()

            return (
                f"Vous avez choisi : {LABEL_B}",
                progress_info,
                rankings_table,
                gr.update(interactive=False),  # button A
                gr.update(interactive=False),  # button B
                gr.update(interactive=True),  # match button
            )
        except Exception as e:
            return (
                f"Erreur lors de l'enregistrement du résultat: {str(e)}",
                "",
                pd.DataFrame(),
            )

    def get_progress_info(self) -> str:
        """
        Obtient les informations sur la progression du tournoi.

        Returns:
            str: Message formaté contenant les statistiques de progression
        """
        if not self.arena:
            return "Aucune arène initialisée. Veuillez d'abord charger des prompts."

        try:
            progress = self.arena.get_progress()

            info = f"Prompts: {progress['total_prompts']}\n"
            info += f"Matchs joués: {progress['total_matches']}\n"
            info += f"Progression: {progress['progress']:.2f}%\n"
            info += (
                f"Matchs restants estimés: {progress['estimated_remaining_matches']}\n"
            )
            info += f"Incertitude moyenne (σ): {progress['avg_sigma']:.4f}"

            return info
        except Exception as e:
            return f"Erreur lors de la récupération de la progression: {str(e)}"

    def get_rankings_table(self) -> pd.DataFrame:
        """
        Obtient le classement des prompts sous forme de tableau.

        Returns:
            pd.DataFrame: Tableau de classement des prompts
        """
        if not self.arena:
            return pd.DataFrame([{"Erreur": "Aucune arène initialisée"}])

        try:
            rankings = self.arena.get_rankings()

            df = pd.DataFrame(rankings)
            df = df[["rank", "prompt_id", "score"]]
            df = df.rename(
                columns={
                    "rank": "Rang",
                    "prompt_id": "ID",
                    "score": "Score",
                }
            )

            return df
        except Exception as e:
            return pd.DataFrame([{"Erreur": str(e)}])

    def create_ui(self) -> gr.Blocks:
        """
        Crée l'interface utilisateur Gradio.

        Returns:
            gr.Blocks: L'application Gradio configurée
        """

        with gr.Blocks(title="Prompt Arena", theme=gr.themes.Ocean()) as app:
            gr.Markdown('<h1 style="text-align:center;">🥊 Prompt Arena 🥊</h1>')

            with gr.Row():
                select_btn = gr.Button("Lancer un nouveau match", variant="primary")

            with gr.Row():
                proposition_a = gr.Textbox(label=LABEL_A, interactive=False)
                proposition_b = gr.Textbox(label=LABEL_B, interactive=False)

            with gr.Row():
                vote_a_btn = gr.Button("Choisir " + LABEL_A, interactive=False)
                vote_b_btn = gr.Button("Choisir " + LABEL_B, interactive=False)

            result = gr.Textbox("Résultat", interactive=False)
            progress_info = gr.Textbox(
                label="Progression du concours", interactive=False
            )
            rankings_table = gr.DataFrame(label="Classement des prompts")
            state = gr.State()  # contient les IDs des prompts du match en cours

            select_btn.click(
                self.select_and_display_match,
                inputs=[],
                outputs=[
                    proposition_a,
                    proposition_b,
                    state,
                    vote_a_btn,
                    vote_b_btn,
                    select_btn,
                ],
            )
            vote_a_btn.click(
                self.record_winner_a,
                inputs=[state],
                outputs=[
                    result,
                    progress_info,
                    rankings_table,
                    vote_a_btn,
                    vote_b_btn,
                    select_btn,
                ],
            )
            vote_b_btn.click(
                self.record_winner_b,
                inputs=[state],
                outputs=[
                    result,
                    progress_info,
                    rankings_table,
                    vote_a_btn,
                    vote_b_btn,
                    select_btn,
                ],
            )

            gr.Row([progress_info, rankings_table])

        return app


# Exemple d'utilisation
if __name__ == "__main__":
    # load the prompts from the CSV file
    prompts = pd.read_csv("prompts.csv", header=None).iloc[:, 0].tolist()
    arena = PromptArena(prompts=prompts)
    app_instance = PromptArenaApp(arena=arena)
    app = app_instance.create_ui()
    app.launch()