|
import gradio as gr |
|
from transformers import pipeline |
|
from utils import * |
|
|
|
# Classification pipeline backed by the "raminass/scotus-v10" model.
# top_k=13 requests the 13 highest-scoring labels for each input.
# padding/truncation are extra pipeline kwargs, presumably forwarded to the
# tokenizer so long opinions are cut to the model's max length — confirm
# against the installed transformers version.
pipe = pipeline(
    model="raminass/scotus-v10",
    top_k=13,
    padding=True,
    truncation=True,
)

# Number of (paragraph text, paragraph prediction) component pairs that the
# UI pre-builds; this caps how many paragraphs can be shown at once.
max_textboxes = 100
|
|
|
|
|
|
|
def greet(opinion):
    """Classify an opinion as a whole and paragraph by paragraph.

    The text has citations stripped (``remove_citations``) and is split into
    chunks (``chunk_data``). Every chunk is then scored through ``pipe`` via
    ``average_text``.

    Args:
        opinion: Raw opinion text from the input textbox.

    Returns:
        A flat list laid out as the Predict button's ``outputs``:
        - first the overall prediction;
        - then one visible (Textbox, Label) pair per displayed chunk;
        - then hidden (Textbox, Label) pairs for the unused slots.
        The list always holds exactly ``1 + 2 * max_textboxes`` items.
    """
    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    # NOTE(review): average_text appears to return (overall, per_chunk),
    # where per_chunk[i] is the prediction for chunks[i] — confirm in utils.
    result = average_text(chunks, pipe)

    # Only max_textboxes slots are pre-built in the UI, so there is nowhere to
    # show more paragraphs than that. Cap explicitly to keep the returned list
    # the same length as the outputs instead of relying on Gradio to cope with
    # surplus values. This also keeps the padding count below non-negative.
    # average_text still receives every chunk.
    k = min(len(chunks), max_textboxes)

    wrt_boxes = []
    for i, chunk in enumerate(chunks[:k]):
        wrt_boxes.append(gr.Textbox(chunk, visible=True))
        wrt_boxes.append(gr.Label(value=result[1][i], visible=True))

    # Hide the remaining slots, which may still show a previous run's output.
    hidden = [gr.Textbox(visible=False), gr.Label(visible=False)] * (max_textboxes - k)

    return [result[0]] + wrt_boxes + hidden
|
|
|
|
|
with gr.Blocks() as demo:
    # Input: the raw opinion text to classify.
    opinion = gr.Textbox(label="Opinion")
    # Overall prediction for the whole opinion. gr.Label replaces the legacy
    # gr.outputs.Label, which was deprecated and removed in Gradio 4.
    op_level = gr.Label(num_top_classes=13, label="Overall")
    greet_btn = gr.Button("Predict")

    # Pre-build max_textboxes hidden (paragraph text, paragraph prediction)
    # pairs; greet() makes visible as many as there are chunks. Components are
    # interleaved text, label, text, label, ... to match greet()'s flat list.
    textboxes = []
    for i in range(max_textboxes):
        t = gr.Textbox(f"Textbox {i}", visible=False, label=f"Paragraph {i+1} Text")
        par_level = gr.Label(
            num_top_classes=5, label=f"Paragraph {i+1} Prediction", visible=False
        )
        textboxes.append(t)
        textboxes.append(par_level)

    greet_btn.click(
        fn=greet,
        inputs=opinion,
        outputs=[op_level] + textboxes,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Gradio app only when this file is run directly, not when it
    # is imported.
    demo.launch()
|
|