"""Gradio demo: predict the author of a court opinion with a text classifier.

The script loads a text-classification pipeline and an opinions dataset at
import time, builds a dropdown of per curiam opinions, and wires a small
Blocks UI in which the user picks (or pastes) an opinion, chooses a year and
a set of candidate judges, and asks for a prediction.
"""

import gradio as gr
import pandas as pd
from transformers import pipeline

# NOTE: the star import stays in its original position relative to the other
# imports because it may shadow (or be shadowed by) names imported around it.
# `chunk_data`, `remove_citations` and `average_text` are expected to come
# from here -- TODO confirm and switch to explicit imports.
from utils import *  # noqa: F401,F403
from datasets import load_dataset

# Opinions whose text is not longer than this are not offered in the dropdown.
MIN_OPINION_CHARS = 1000
# Range of the year slider; the Clear button resets the slider to FIRST_YEAR.
FIRST_YEAR = 1994
LAST_YEAR = 2020

pipe = pipeline(model="raminass/scotus-v10", top_k=13, padding=True, truncation=True)

# Named `dataset` rather than `all` so the builtin `all()` is not shadowed.
dataset = load_dataset("raminass/full_opinions_1994_2020")
df = pd.DataFrame(dataset["train"])

# Dropdown entries: (display name, [opinion text, year filed]).
per_curiam_opinions = df[df.category == "per_curiam"]
choices = [
    (str(row["case_name"]), [row["text"], row["year_filed"]])
    for _, row in per_curiam_opinions.iterrows()
    if len(row["text"]) > MIN_OPINION_CHARS
]

# Authors seen in each year, excluding the "per_curiam" pseudo-author.
unique_judges_by_year = (
    df[df.author_name != "per_curiam"].groupby("year_filed")["author_name"].unique()
)
# Two judges are added by hand to the 1994 entry. After this assignment the
# 1994 entry is a plain list while the other years keep whatever `.unique()`
# produced, so readers of this Series must not assume a `.tolist()` method.
additional_judges = ["Justice Breyer", "Justice Kennedy"]
unique_judges_by_year[1994] = list(unique_judges_by_year[1994]) + additional_judges


def _judges_for_year(year):
    """Return the judges recorded for *year* as a plain list.

    Works for both the list stored for 1994 and the array-like values stored
    for the other years.
    """
    return list(unique_judges_by_year[year])


# https://www.gradio.app/guides/controlling-layout
def greet(opinion, judges_l):
    """Run the classifier on *opinion* and return the value shown in the label.

    Args:
        opinion: Opinion text taken from the "Opinion" textbox.
        judges_l: Judges ticked in the "Select Judges" checkbox group; passed
            through to ``average_text``.

    Returns:
        The first element of ``average_text``'s result.

    Raises:
        gr.Error: If no opinion text was supplied.
    """
    # The Clear button sets the textbox to None; fail with a readable message
    # instead of passing an empty value into the text helpers.
    if not opinion or not opinion.strip():
        raise gr.Error("Please paste an opinion or select one from the dropdown.")

    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    result = average_text(chunks, pipe, judges_l)
    return result[0]


def set_input(drop):
    """Copy a selected dropdown entry into the opinion textbox and year slider.

    Args:
        drop: The selected dropdown value, ``[opinion text, year filed]``.

    Returns:
        The opinion text, the year, and a slider update keeping it visible.
    """
    return drop[0], drop[1], gr.Slider(visible=True)


def update_year(year):
    """Return a checkbox-group update listing (and selecting) *year*'s judges."""
    judges = _judges_for_year(year)
    return gr.CheckboxGroup(
        judges,
        value=judges,
        label="Select Judges",
    )


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            drop = gr.Dropdown(
                choices=sorted(choices),
                label="Per Curiam Opinions",
                info="Select a per curiam opinion to use as input",
            )
            year = gr.Slider(
                FIRST_YEAR,
                LAST_YEAR,
                step=1,
                label="Year",
                info="Select the year of the opinion if you manually pass the opinion below",
            )
            initial_judges = _judges_for_year(year.value)
            exc_judg = gr.CheckboxGroup(
                initial_judges,
                value=initial_judges,
                label="Select Judges",
                info="Select judges to consider in prediction",
            )
            opinion = gr.Textbox(
                label="Opinion", info="Paste opinion text here or select from dropdown"
            )
        # Right column: action buttons and the prediction.
        with gr.Column():
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
            # `gr.Label` replaces the legacy `gr.outputs.Label` spelling.
            op_level = gr.Label(
                num_top_classes=9, label="Predicted author of opinion"
            )

    # Refresh the judge list whenever the year changes, whether the user
    # releases the slider or the value is set by another event.
    year.release(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    year.change(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])
    greet_btn.click(
        fn=greet,
        inputs=[opinion, exc_judg],
        outputs=[op_level],
    )
    clear_btn.click(
        fn=lambda: [None, FIRST_YEAR, gr.Slider(visible=True), None, None],
        outputs=[opinion, year, year, drop, op_level],
    )

if __name__ == "__main__":
    demo.launch(debug=True)