Spaces:
Running
Running
File size: 4,344 Bytes
f80cc50 8ddc567 72f0cff 71645c3 8fbc997 72f0cff af9fa51 5447c0b dc881f6 71645c3 2b4f66b 926ab2a ca9892d f80cc50 8fbc997 f80cc50 8ddc567 9c55aba 8ddc567 fb7fb6c 9c55aba d4e0539 8ddc567 f80cc50 71645c3 9c55aba 8fbc997 a4466d4 9c55aba 71645c3 95eb387 71645c3 f80cc50 fb7fb6c eecd399 9c55aba a4466d4 9c55aba 8fbc997 9c55aba 0805275 9c55aba 8fbc997 f1d21d9 9c55aba a4466d4 9c55aba eecd399 926ab2a 9c55aba fb7fb6c d4e0539 95eb387 8ddc567 9c55aba 926ab2a 71645c3 8ddc567 9c55aba d4e0539 8ddc567 926ab2a 9c55aba 926ab2a f80cc50 33e28cc 747849d 33e28cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import gradio as gr
from transformers import pipeline
from utils import *
from datasets import load_dataset
import json
# Text-classification pipeline over the "raminass/m4" model; top_k=17 requests
# the 17 highest-scoring labels for each input, and truncation=True lets the
# tokenizer cut over-long chunks instead of failing on them.
pipe = pipeline(model="raminass/m4", top_k=17, padding=True, truncation=True)

# Earlier corpus, kept for reference: "raminass/full_opinions_1994_2020".
# Named ``dataset`` rather than ``all`` so the builtin all() is not shadowed.
dataset = load_dataset("raminass/opinions-94-23")
df = pd.DataFrame(dataset["train"])

# Per curiam opinions populate the example dropdown. .copy() makes this an
# independent frame, so the column assignment below modifies it rather than a
# view of ``df`` (which would trigger pandas' SettingWithCopyWarning).
percuriams = df[df["type"] == "per_curiam"].copy()
percuriams["case_name"] = percuriams["case_name"].str.strip()
percuriams = percuriams.sort_values(by="case_name", key=lambda s: s.str.lower())

# Only opinions longer than this many characters are offered as examples.
MIN_OPINION_CHARS = 1000
long_opinions = percuriams[percuriams["text"].str.len() > MIN_OPINION_CHARS]

# Dropdown choices as (label, value) pairs, kept in the case-insensitive order
# established above. The value is [text, year], which set_input unpacks into
# the opinion textbox and the year slider.
choices = [
    (str(name), [text, year])
    for name, text, year in zip(
        long_opinions["case_name"], long_opinions["text"], long_opinions["year"]
    )
]

# j_year.json maps a year to the justices offered as checkbox choices for that
# year. JSON object keys are always strings, so they are converted to int to
# match the numeric year coming from the slider.
with open("j_year.json", "r", encoding="utf-8") as j:
    judges_by_year = {int(k): v for k, v in json.load(j).items()}
# https://www.gradio.app/guides/controlling-layout
def greet(opinion, judges_l):
    """Score an opinion's likely author among the selected justices.

    The text has citations removed (``remove_citations``) and is split into
    chunks (``chunk_data``). The chunk list, the classification ``pipe`` and
    ``judges_l`` are then handed to ``average_text``. All three helpers come
    from ``utils``; presumably ``average_text`` aggregates the per-chunk scores
    over the selected justices (confirm against utils).

    Args:
        opinion: Opinion text from the textbox. This is ``None`` after the
            Clear button has been pressed.
        judges_l: Justices selected in the "Select Justices" checkbox group.

    Returns:
        A pair ``(scores, percentages)``. ``scores`` is the first element of
        ``average_text``'s result and feeds the Label component.
        ``percentages`` is the same mapping with every value multiplied by 100
        and rounded to 2 decimals, for the text output.

    Raises:
        gr.Error: If ``opinion`` is empty, ``None`` or whitespace only.
    """
    # Clear resets the textbox to None. Stop here with a readable message
    # rather than passing None/blank text into the text-processing helpers.
    if not opinion or not opinion.strip():
        raise gr.Error("Please paste an opinion or select one from the dropdown.")
    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    result = average_text(chunks, pipe, judges_l)
    scores = result[0]
    return scores, {k: round(v * 100, 2) for k, v in scores.items()}
def set_input(drop):
    """Copy a dropdown selection into the opinion textbox and the year slider.

    ``drop`` is the selected choice's value, which the choices list builds as
    ``[text, year]``. Returns the text, the year, and a Slider update that keeps
    the slider visible (the event lists ``year`` twice as an output to match).
    """
    selected_text = drop[0]
    selected_year = drop[1]
    slider_update = gr.Slider(visible=True)
    return selected_text, selected_year, slider_update
def update_year(year):
    """Rebuild the justice checkbox group for ``year``.

    Every justice listed for ``year`` in ``judges_by_year`` becomes a choice,
    and all of them start out checked. Raises ``KeyError`` if the year has no
    entry in the mapping.
    """
    justices = judges_by_year[year]
    return gr.CheckboxGroup(justices, value=justices, label="Select Justices")
# Paragraph text shown in the read-only "Additional Insights" box. It relates
# the model's raw prediction scores to the accuracy seen in cross-validation.
# The sentences are joined with single spaces.
paragraph_text = " ".join(
    (
        "One can refine these observations based on the prediction scores "
        "obtained in each case.",
        "As explained in the Methods, these scores do not correspond to "
        "probabilities but can be calibrated based on the cross-validation "
        "results.",
        "In particular, when we trained the algorithm we noticed that if the "
        "top prediction score is greater than 40% (50%, 60%), our accuracy in "
        "predicting the authoring justice increases to 93% (95%, 96%, "
        "respectively).",
        "Similarly, the original accuracy further improves to 95% when "
        "considering the top two predictions per opinion, rather than a "
        "single one.",
        "If the sum of the top two prediction scores exceeds 50%, the "
        "accuracy increases to 98%.",
    )
)
# Year range offered by the slider; presumably the 1994-2023 span of the
# "opinions-94-23" dataset. FIRST_YEAR is also the slider's reset value.
FIRST_YEAR = 1994
LAST_YEAR = 2023

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            drop = gr.Dropdown(
                # ``choices`` is already sorted case-insensitively by case
                # name. Wrapping it in sorted() would re-sort it
                # case-sensitively and undo that ordering.
                choices=choices,
                label="List of Per Curiam Opinions",
                info="Select a per curiam opinion from the dropdown menu and press the Predict Button",
            )
            year = gr.Slider(
                FIRST_YEAR,
                LAST_YEAR,
                value=FIRST_YEAR,
                step=1,
                label="Year",
                info="Select the year of the opinion if you manually paste the opinion below",
            )
            # Start with every justice listed for FIRST_YEAR offered and
            # checked. update_year rebuilds this group when the year changes.
            exc_judg = gr.CheckboxGroup(
                judges_by_year[FIRST_YEAR],
                value=judges_by_year[FIRST_YEAR],
                label="Select Justices",
                info="Select justices to consider in prediction",
            )
            opinion = gr.Textbox(
                label="Opinion", info="Paste opinion text here and press the Predict Button"
            )
        with gr.Column(scale=1):
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
            # gr.outputs.* is a deprecated namespace (removed in Gradio 4).
            # gr.Label is the same component under its supported name.
            op_level = gr.Label(num_top_classes=9, label="Predicted author of opinion")
            output_textbox = gr.Textbox(label="Output Text", show_copy_button=True)
            info_textbox = gr.Textbox(
                value=paragraph_text,
                label="Additional Insights",
                interactive=False,  # Makes the textbox read-only
            )

    # .change fires for user input as well as for handler updates (dropdown
    # selection, Clear). A separate .release listener would only run
    # update_year a second time for the same drag.
    year.change(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    # set_input returns (text, year, Slider update), hence ``year`` twice.
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])
    greet_btn.click(
        fn=greet,
        inputs=[opinion, exc_judg],
        outputs=[op_level, output_textbox],
    )
    # Clear resets the inputs and both prediction outputs. Moving the slider
    # back to FIRST_YEAR triggers year.change (when the year actually changes),
    # which restores that year's justice list.
    clear_btn.click(
        fn=lambda: [None, FIRST_YEAR, None, None, None],
        outputs=[opinion, year, drop, op_level, output_textbox],
    )
if __name__ == "__main__":
    # Launch with Gradio's default settings. To require a login, pass
    # ``auth=(username, password)`` and optionally ``auth_message=...``.
    # Read credentials from the environment rather than hard-coding them.
    demo.launch()
|