Spaces:
Sleeping
Sleeping
File size: 4,190 Bytes
1921336 42c830b f6590f0 9a1ab03 a797164 488c5c4 3f9babb 96b2814 3f9babb de53991 34ffea3 f6590f0 1921336 488c5c4 4ea28ea 488c5c4 ed67a17 488c5c4 fa738bd 488c5c4 c0a9946 64df9ac 1921336 64df9ac 1921336 64df9ac 1921336 9a1ab03 f664ce2 3f9babb f664ce2 9e26af4 96b2814 64df9ac 96b2814 64df9ac 96b2814 64df9ac 96b2814 64df9ac de53991 1921336 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
#from utils.multiple_stream import create_interface
import random
import gradio as gr
import json
from utils.data import dataset
from utils.multiple_stream import stream_data
from pages.summarization_playground import get_model_batch_generation
from pages.summarization_playground import custom_css
# Holds the user's chosen response label ("Response 1" / "Response 2" / "Response 3").
# Set by lock_selection; None until the user votes. NOTE: a bare `global` statement at
# module level is a no-op and never creates the name — it must be assigned here so that
# reads before the first vote do not raise NameError.
global_selected_choice = None
def random_data_selection():
    """Pick one random entry from the dataset and format it for display.

    Returns:
        str: the entry's section text followed by its dialogue, separated
        by a "Dialogue:" header.
    """
    sample = random.choice(dataset)
    return f"{sample['section_text']}\n\nDialogue:\n{sample['dialogue']}"
# Function to handle user selection and disable the radio
def lock_selection(selected_option):
    """Record the user's vote and freeze the voting widgets.

    Stores *selected_option* in the module-level ``global_selected_choice``,
    then returns Gradio updates that reveal the result textbox and disable
    both the radio group and the submit button so the vote cannot change.

    Args:
        selected_option: the label of the chosen response.

    Returns:
        tuple: (show-result update, echoed selection, disable-radio update,
        disable-button update).
    """
    global global_selected_choice
    global_selected_choice = selected_option
    reveal = gr.update(visible=True)
    freeze = gr.update(interactive=False)
    return reveal, selected_option, freeze, gr.update(interactive=False)
def create_arena():
    """Build the prompt-comparison arena UI.

    Loads candidate prompts from prompt/prompt.json, streams model responses
    for three randomly chosen prompts side by side, and — when the user votes
    for the best response — increments that prompt's win counter and writes
    the updated prompt list back to disk.

    Returns:
        gr.Blocks: the assembled Gradio demo.
    """
    # json.load reads straight from the file handle (no intermediate string).
    with open("prompt/prompt.json", "r") as file:
        prompts = json.load(file)

    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm", text_size="sm"), css=custom_css) as demo:
        with gr.Group():
            datapoint = random_data_selection()
            gr.Markdown("""This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.
Once the streaming is complete, you can choose the best response.\u2764\ufe0f""")
            data_textbox = gr.Textbox(label="Data", lines=10, placeholder="Datapoints to test...", value=datapoint)

            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                submit_button = gr.Button("✨ Click to Streaming ✨")

            random_selection_button.click(
                fn=random_data_selection,
                inputs=[],
                outputs=[data_textbox]
            )

            # Pick three prompts at random for this session's comparison.
            random.shuffle(prompts)
            random_selected_prompts = prompts[:3]

            with gr.Row():
                columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10) for i in range(len(random_selected_prompts))]

            model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

            def start_streaming(data):
                # Build the contents at click time from the *current* textbox
                # value, so pressing "Change Data" actually affects streaming.
                # (Previously the content was captured once at layout time.)
                content_list = [prompt['prompt'] + '\n{' + data + '}\n\nsummary:' for prompt in random_selected_prompts]
                for chunk in stream_data(content_list, model):
                    yield tuple(gr.update(value=chunk[i]) for i in range(len(columns)))

            submit_button.click(
                fn=start_streaming,
                inputs=[data_textbox],
                outputs=columns,
                show_progress=False
            )

            choice = gr.Radio(label="Choose the best response:", choices=["Response 1", "Response 2", "Response 3"])
            # Distinct name: reassigning submit_button here would shadow the
            # streaming button and confuse later reads of the code.
            vote_button = gr.Button("Submit")
            # Output to display the selected option (hidden until a vote lands).
            output = gr.Textbox(label="You selected:", visible=False)

            def record_vote(selected_option):
                # Runs on click — the original tallying code executed once at
                # build time, when no vote existed yet, and crashed on startup.
                index = {"Response 1": 0, "Response 2": 1, "Response 3": 2}.get(selected_option)
                if index is None:
                    raise ValueError(f"No corresponding response of {selected_option}")
                # Compare id to id — the original compared an id to a whole
                # prompt dict, which never matched.
                prompt_id = random_selected_prompts[index]['id']
                for prompt in prompts:
                    if prompt['id'] == prompt_id:
                        prompt["metric"]["winning_number"] += 1
                        break
                else:
                    raise ValueError(f"No prompt of id {prompt_id}")
                # Persist the updated win counters.
                with open("prompt/prompt.json", "w") as f:
                    json.dump(prompts, f)
                # Reveal the result textbox and lock the voting widgets.
                return gr.update(visible=True), selected_option, gr.update(interactive=False), gr.update(interactive=False)

            vote_button.click(fn=record_vote, inputs=choice, outputs=[output, output, choice, vote_button])

    return demo
# Script entry point: build the arena UI and serve it locally.
if __name__ == "__main__":
    demo = create_arena()
    demo.queue()  # queuing is required for generator-based streaming outputs
    demo.launch()
|