Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
from utils import * | |
from datasets import load_dataset | |
# Hugging Face pipeline for the raminass/scotus-v10 model. top_k=13 asks for
# the 13 highest-scoring labels (matches the 13-class "Overall" label below);
# padding/truncation let arbitrarily long chunks be fed to the model.
pipe = pipeline(model="raminass/scotus-v10", top_k=13, padding=True, truncation=True)

# Opinions corpus. Named `dataset` rather than `all` so the builtin all() is
# not shadowed for the rest of this module.
dataset = load_dataset("raminass/full_opinions_1994_2020")
df = pd.DataFrame(dataset["train"])

# Dropdown choices: per-curiam opinions longer than 1000 characters, as
# (case name, [opinion text, year filed]) pairs. set_input() reads index 0 and
# index 1 of the selected value.
choices = [
    (f"{row['case_name']}", [row["text"], row["year_filed"]])
    for _, row in df[df.category == "per_curiam"].iterrows()
    if len(row["text"]) > 1000
]

# Number of (paragraph text, paragraph prediction) rows pre-built in the UI.
# Gradio needs a fixed set of output components, so unused rows are hidden.
max_textboxes = 100
# Layout reference: https://www.gradio.app/guides/controlling-layout
def greet(opinion, year):
    """Run the model on ``opinion`` and build the Predict button's outputs.

    Args:
        opinion: Opinion text from the "Opinion" textbox.
        year: Filing year from the "Year" slider. It selects, from ``df``, the
            authors of non-per-curiam opinions filed that year. That list is
            passed to ``average_text`` (presumably to restrict the candidate
            judges -- confirm in utils).

    Returns:
        A flat list with exactly ``1 + 2 * max_textboxes`` entries:
        ``result[0]`` from ``average_text`` (shown in the "Overall" label),
        then one (Textbox, Label) pair per paragraph row. Rows with a chunk
        are made visible; the remaining rows are hidden. If the text splits
        into more than ``max_textboxes`` chunks, only the first
        ``max_textboxes`` are displayed.
    """
    judges_l = (
        df[(df["year_filed"] == year) & (df["category"] != "per_curiam")]
        .author_name.unique()
        .tolist()
    )
    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    result = average_text(chunks, pipe, judges_l)

    # Clamp to the number of pre-built rows. Without this, a long opinion
    # makes (max_textboxes - k) negative, the hidden-row padding collapses to
    # [], and the returned list no longer matches the wired outputs.
    shown = min(len(chunks), max_textboxes)

    wrt_boxes = []
    for i in range(shown):
        wrt_boxes.append(gr.Textbox(chunks[i], visible=True))
        wrt_boxes.append(gr.Label(value=result[1][i], visible=True))

    hidden_rows = [gr.Textbox(visible=False), gr.Label(visible=False)] * (
        max_textboxes - shown
    )
    return [result[0]] + wrt_boxes + hidden_rows
def set_input(drop):
    """Copy the selected dropdown case into the opinion and year inputs.

    Args:
        drop: Value of the selected dropdown choice. Index 0 is the opinion
            text and index 1 the year filed, as built for ``choices``.

    Returns:
        A 3-tuple ``(text, year, slider_update)``. The last element hides the
        year slider; its value is taken from the selection instead.
    """
    selected_text = drop[0]
    selected_year = drop[1]
    hide_year_slider = gr.Slider(visible=False)
    return selected_text, selected_year, hide_year_slider
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            opinion = gr.Textbox(label="Opinion")
            year = gr.Slider(1994, 2020, step=1, label="Year")
            drop = gr.Dropdown(choices=sorted(choices))
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
        # Was gr.outputs.Label. The gr.outputs namespace no longer exists in
        # Gradio 4, and this file already relies on the Gradio 4 style of
        # returning component instances as updates, so use gr.Label directly.
        op_level = gr.Label(num_top_classes=13, label="Overall")

    # Fixed pool of hidden (paragraph text, paragraph prediction) rows. greet()
    # reveals one row per chunk, and the clear handler hides them all again.
    textboxes = []
    for i in range(max_textboxes):
        with gr.Row():
            t = gr.Textbox(f"Textbox {i}", visible=False, label=f"Paragraph {i+1} Text")
            par_level = gr.Label(
                num_top_classes=5, label=f"Paragraph {i+1} Prediction", visible=False
            )
        textboxes.append(t)
        textboxes.append(par_level)

    # `year` is listed twice. set_input returns the year value and then a
    # separate slider update that hides it.
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])
    greet_btn.click(
        fn=greet,
        inputs=[opinion, year],
        outputs=[op_level] + textboxes,
    )
    # Reset the text, year (value 1994 and visible again), dropdown and
    # overall label, and hide every paragraph row.
    clear_btn.click(
        fn=lambda: [None, 1994, gr.Slider(visible=True), None, None]
        + [gr.Textbox(visible=False), gr.Label(visible=False)] * max_textboxes,
        outputs=[opinion, year, year, drop, op_level] + textboxes,
    )
def main():
    """Start the Gradio server for the ``demo`` app built in this module."""
    demo.launch()


if __name__ == "__main__":
    # Launch only when executed as a script, not when the module is imported.
    main()