File size: 4,190 Bytes
1921336
 
 
 
 
 
 
42c830b
f6590f0
9a1ab03
a797164
 
488c5c4
 
 
 
 
 
3f9babb
 
96b2814
 
3f9babb
 
de53991
34ffea3
 
 
 
f6590f0
1921336
488c5c4
4ea28ea
 
 
 
488c5c4
ed67a17
488c5c4
 
 
 
 
 
fa738bd
488c5c4
c0a9946
 
64df9ac
1921336
 
64df9ac
1921336
64df9ac
1921336
 
 
 
 
 
 
 
 
 
 
 
 
9a1ab03
f664ce2
 
 
 
3f9babb
 
 
 
f664ce2
9e26af4
96b2814
64df9ac
96b2814
64df9ac
96b2814
64df9ac
 
96b2814
64df9ac
 
 
 
 
 
 
 
 
 
 
 
de53991
1921336
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#from utils.multiple_stream import create_interface
import random
import gradio as gr
import json

from utils.data import dataset
from utils.multiple_stream import stream_data
from pages.summarization_playground import get_model_batch_generation
from pages.summarization_playground import custom_css

global global_selected_choice

def random_data_selection():
    """Sample one datapoint from the dataset, formatted for the Data textbox."""
    sample = random.choice(dataset)
    return f"{sample['section_text']}\n\nDialogue:\n{sample['dialogue']}"

def lock_selection(selected_option):
    """Record the user's pick in the module-level global and freeze the form.

    Returns the four component updates expected by the Submit click handler:
    reveal the result textbox, echo the chosen option into it, and disable
    both the radio group and the submit button.
    """
    global global_selected_choice
    global_selected_choice = selected_option
    reveal = gr.update(visible=True)
    freeze_radio = gr.update(interactive=False)
    freeze_button = gr.update(interactive=False)
    return reveal, selected_option, freeze_radio, freeze_button

def create_arena():
    """Build the prompt-comparison arena UI and return the gr.Blocks demo.

    Three prompts are sampled from prompt/prompt.json; the user streams one
    response per prompt for a datapoint, votes for the best response, and the
    winning prompt's ``metric.winning_number`` is persisted back to the file.

    Fixes over the original:
    - vote recording now happens inside a click callback; the original ran it
      at UI-construction time, before any vote existed, so ``create_arena()``
      always raised on startup.
    - the winner lookup compares prompt *ids*; the original compared an id
      against a whole prompt dict and could never match.
    - the streamed content is built from the textbox's current value, so the
      "Change Data" button actually affects what gets streamed (the original
      snapshotted the value at build time).
    - the vote button no longer shadows the streaming ``submit_button``.
    """
    with open("prompt/prompt.json", "r") as file:
        prompts = json.load(file)

    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm", text_size="sm"), css=custom_css) as demo:
        with gr.Group():
            datapoint = random_data_selection()
            gr.Markdown("""This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.

Once the streaming is complete, you can choose the best response.\u2764\ufe0f""")

            data_textbox = gr.Textbox(label="Data", lines=10, placeholder="Datapoints to test...", value=datapoint)
            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                submit_button = gr.Button("✨ Click to Streaming ✨")

            random_selection_button.click(
                fn=random_data_selection,
                inputs=[],
                outputs=[data_textbox]
            )

            random.shuffle(prompts)
            random_selected_prompts = prompts[:3]

            with gr.Row():
                columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10) for i in range(len(random_selected_prompts))]

            model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

            def start_streaming(data):
                # Build the payloads from the *current* textbox content, not a
                # snapshot taken when the UI was constructed.
                content_list = [
                    prompt['prompt'] + '\n{' + data + '}\n\nsummary:'
                    for prompt in random_selected_prompts
                ]
                for chunk in stream_data(content_list, model):
                    yield tuple(gr.update(value=chunk[i]) for i in range(len(columns)))

            submit_button.click(
                fn=start_streaming,
                inputs=[data_textbox],
                outputs=columns,
                show_progress=False
            )

            choice = gr.Radio(label="Choose the best response:", choices=["Response 1", "Response 2", "Response 3"])

            vote_button = gr.Button("Submit")

            # Output to display the selected option
            output = gr.Textbox(label="You selected:", visible=False)

            def record_vote(selected_option):
                """Persist the winning prompt's tally and lock the vote controls."""
                global global_selected_choice
                global_selected_choice = selected_option

                index = {"Response 1": 0, "Response 2": 1, "Response 3": 2}.get(selected_option)
                if index is None:
                    raise ValueError(f"No corresponding response of {selected_option}")

                # Compare ids, not dicts: the original stored the whole prompt
                # dict in prompt_id and the lookup below never matched.
                winner_id = random_selected_prompts[index]['id']
                for prompt in prompts:
                    if prompt['id'] == winner_id:
                        prompt["metric"]["winning_number"] += 1
                        break
                else:
                    raise ValueError(f"No prompt of id {winner_id}")

                with open("prompt/prompt.json", "w") as f:
                    json.dump(prompts, f)

                return (
                    gr.update(value=selected_option, visible=True),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )

            vote_button.click(fn=record_vote, inputs=choice, outputs=[output, choice, vote_button])

    return demo

if __name__ == "__main__":
    # Build the arena UI and serve it locally.
    demo = create_arena()
    # Queuing is required so the streaming generator can yield incremental updates.
    demo.queue()
    demo.launch()