import random
import gradio as gr
import json
import logging
import gc
import torch

from utils.data import dataset
from utils.multiple_stream import stream_data
from pages.summarization_playground import get_model_batch_generation


DEFAULT_PROMPT_PATH = "prompt/prompt.json"
DEFAULT_MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"


def _load_prompts(path):
    """Parse and return the JSON document at *path* (the prompt templates).

    The file is decoded as UTF-8 explicitly so prompts containing non-ASCII
    text load identically regardless of the platform's default encoding.
    """
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)


def _format_datapoint(record):
    """Render one dataset record as the context text appended to each prompt.

    Reads the 'section_text' and 'dialogue' keys of *record*; presumably both
    are strings — TODO confirm against ``utils.data.dataset``.
    """
    return record['section_text'] + '\n\nDialogue:\n' + record['dialogue']


def create_arena(model_name=DEFAULT_MODEL_NAME, prompt_path=DEFAULT_PROMPT_PATH):
    """Build a Gradio app that runs every stored prompt against one sample.

    One record is drawn at random from ``dataset`` and appended to each prompt
    template loaded from *prompt_path*. Clicking Submit streams the model's
    outputs, one textbox per prompt.

    Args:
        model_name: Identifier passed to ``get_model_batch_generation``.
        prompt_path: Path to the JSON file holding the prompt templates.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    prompts = _load_prompts(prompt_path)

    with gr.Blocks() as demo:
        with gr.Group():
            # Sampled once when the app is built, so every Submit click reuses
            # the same datapoint until the app is rebuilt.
            datapoint = _format_datapoint(random.choice(dataset))
            submit_button = gr.Button("✨ Submit ✨")
            with gr.Row():
                columns = [
                    gr.Textbox(label=f"Prompt {i + 1}", lines=10)
                    for i in range(len(prompts))
                ]

            # Each model input is: template, the context wrapped in braces, and
            # a trailing "summary:" cue.
            content_list = [
                prompt + '\n{' + datapoint + '}\n\nsummary:' for prompt in prompts
            ]
            model = get_model_batch_generation(model_name)

            def start_streaming():
                # Presumably each yielded ``data`` holds the current text for
                # every prompt, indexed like ``content_list`` — confirm against
                # ``stream_data``. One update is emitted per output textbox.
                for data in stream_data(content_list, model):
                    yield tuple(gr.update(value=data[i]) for i in range(len(columns)))

            submit_button.click(
                fn=start_streaming,
                inputs=[],
                outputs=columns,
                show_progress=False,
            )

    return demo


if __name__ == "__main__":
    demo = create_arena()
    # Enable Gradio's request queue; NOTE(review): some Gradio versions require
    # it for generator (streaming) callbacks such as start_streaming.
    demo.queue()
    demo.launch()