File size: 4,344 Bytes
f80cc50
8ddc567
72f0cff
71645c3
8fbc997
72f0cff
af9fa51
5447c0b
dc881f6
 
 
71645c3
 
2b4f66b
 
 
 
 
926ab2a
ca9892d
f80cc50
8fbc997
 
 
f80cc50
8ddc567
 
9c55aba
8ddc567
fb7fb6c
9c55aba
d4e0539
8ddc567
f80cc50
71645c3
9c55aba
 
 
 
 
8fbc997
 
a4466d4
9c55aba
71645c3
95eb387
 
 
 
 
 
 
 
 
 
71645c3
f80cc50
fb7fb6c
eecd399
9c55aba
 
a4466d4
 
9c55aba
 
 
8fbc997
9c55aba
 
0805275
9c55aba
 
8fbc997
 
f1d21d9
 
9c55aba
 
 
a4466d4
9c55aba
eecd399
926ab2a
 
 
9c55aba
 
fb7fb6c
d4e0539
95eb387
 
 
 
 
8ddc567
9c55aba
 
 
 
 
 
 
 
 
 
926ab2a
71645c3
8ddc567
 
9c55aba
d4e0539
8ddc567
 
926ab2a
9c55aba
 
926ab2a
 
f80cc50
 
33e28cc
747849d
 
33e28cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
from transformers import pipeline
from utils import *
from datasets import load_dataset
import json

# Hugging Face pipeline for the "raminass/m4" model. top_k, padding and
# truncation are forwarded as pipeline options.
pipe = pipeline(model="raminass/m4", top_k=17, padding=True, truncation=True)

# Alternative dataset (currently disabled):
# load_dataset("raminass/full_opinions_1994_2020")
# Bound to a descriptive name so the builtin all() is not shadowed.
opinions_dataset = load_dataset("raminass/opinions-94-23")

df = pd.DataFrame(opinions_dataset["train"])

# .copy() detaches the filtered rows from df, so the column assignment below
# writes to an independent frame instead of raising SettingWithCopyWarning.
percuriams = df[df["type"] == "per_curiam"].copy()
percuriams["case_name"] = percuriams["case_name"].str.strip()
# Case-insensitive alphabetical order by case name.
percuriams = percuriams.sort_values(by="case_name", key=lambda x: x.str.lower())

# Dropdown entries as (displayed case name, [opinion text, year]); opinions of
# 1000 characters or fewer are left out.
choices = [
    (str(case_name), [text, yr])
    for case_name, text, yr in zip(
        percuriams["case_name"], percuriams["text"], percuriams["year"]
    )
    if len(text) > 1000
]

# JSON object keys are always strings; convert them to ints so the mapping can
# be indexed with the numeric value of the year slider.
with open("j_year.json", "r", encoding="utf-8") as j:
    judges_by_year = {int(k): v for k, v in json.load(j).items()}


# https://www.gradio.app/guides/controlling-layout
def greet(opinion, judges_l):
    """Score an opinion text against the selected justices.

    The text is passed through ``remove_citations`` and ``chunk_data``; the
    "text" column of the chunked result is handed to ``average_text`` together
    with the module-level ``pipe`` and ``judges_l``.

    Args:
        opinion: Opinion text taken from the "Opinion" textbox.
        judges_l: Justices ticked in the "Select Justices" checkbox group.

    Returns:
        A 2-tuple: the first element of ``average_text``'s result (a mapping,
        presumably justice name -> score -- confirm against ``utils``), and a
        dict with the same keys whose values are multiplied by 100 and rounded
        to two decimals.
    """
    without_citations = remove_citations(opinion)
    segments = chunk_data(without_citations)["text"].to_list()
    scores = average_text(segments, pipe, judges_l)[0]

    as_percent = {}
    for name, score in scores.items():
        as_percent[name] = round(score * 100, 2)
    return scores, as_percent


def set_input(drop):
    """Fill the form from a dropdown selection.

    Args:
        drop: The selected dropdown value; indexed as ``[text, year]``.

    Returns:
        A 3-tuple of the opinion text, its year, and a ``gr.Slider`` with
        ``visible=True``.
    """
    opinion_text = drop[0]
    opinion_year = drop[1]
    shown_slider = gr.Slider(visible=True)
    return opinion_text, opinion_year, shown_slider


def update_year(year):
    """Build the justice checkbox group for the given year.

    Args:
        year: Key into ``judges_by_year`` (the current slider value).

    Returns:
        A ``gr.CheckboxGroup`` labelled "Select Justices" whose choices are the
        entries of ``judges_by_year[year]``, all of them pre-selected.

    Raises:
        KeyError: If ``year`` is not a key of ``judges_by_year``.
    """
    justices = judges_by_year[year]
    return gr.CheckboxGroup(justices, value=justices, label="Select Justices")

# Explanatory note shown in the read-only "Additional Insights" textbox.
# Written one sentence per tuple item; the items are joined with single spaces.
paragraph_text = " ".join(
    (
        "One can refine these observations based on the prediction scores "
        "obtained in each case.",
        "As explained in the Methods, these scores do not correspond to "
        "probabilities but can be calibrated based on the cross-validation "
        "results.",
        "In particular, when we trained the algorithm we noticed that if the "
        "top prediction score is greater than 40% (50%, 60%), our accuracy in "
        "predicting the authoring justice increases to 93% (95%, 96%, "
        "respectively).",
        "Similarly, the original accuracy further improves to 95% when "
        "considering the top two predictions per opinion, rather than a "
        "single one.",
        "If the sum of the top two prediction scores exceeds 50%, the "
        "accuracy increases to 98%.",
    )
)

# UI layout: a wide input column (case dropdown, year slider, justice
# checkboxes, opinion text) next to a narrower results column.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            drop = gr.Dropdown(
                # Order by the displayed case name, ignoring case. A bare
                # sorted() compares code points (uppercase before lowercase)
                # and, on equal names, falls through to the [text, year]
                # payloads.
                choices=sorted(choices, key=lambda choice: choice[0].lower()),
                label="List of Per Curiam Opinions",
                info="Select a per curiam opinion from the dropdown menu and press the Predict Button",
            )
            year = gr.Slider(
                1994,
                2023,
                value=1994,  # explicit, because year.value is read just below
                step=1,
                label="Year",
                info="Select the year of the opinion if you manually paste the opinion below",
            )
            exc_judg = gr.CheckboxGroup(
                judges_by_year[year.value],
                value=judges_by_year[year.value],
                label="Select Justices",
                info="Select justices to consider in prediction",
            )

            opinion = gr.Textbox(
                label="Opinion", info="Paste opinion text here and press the Predict Button"
            )
        with gr.Column(scale=1):
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
            # gr.Label replaces the legacy gr.outputs.Label, which is
            # deprecated in Gradio 3.x and no longer present in 4.x.
            op_level = gr.Label(
                num_top_classes=9, label="Predicted author of opinion"
            )
            output_textbox = gr.Textbox(label="Output Text", show_copy_button=True)
            info_textbox = gr.Textbox(
                value=paragraph_text,
                label="Additional Insights",
                interactive=False,  # Makes the textbox read-only
            )

    # Refresh the justice checkboxes whenever the year moves. Both events are
    # wired to the same handler: release for the user letting go of the
    # slider, change for any value change (including ones set by handlers).
    year.release(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    year.change(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    # `year` appears twice because set_input returns both the new year value
    # and a gr.Slider(visible=True) update for the same component.
    drop.select(set_input, inputs=drop, outputs=[opinion, year, year])

    greet_btn.click(
        fn=greet,
        inputs=[opinion, exc_judg],
        outputs=[op_level, output_textbox],
    )

    # Reset the form. output_textbox is cleared as well, so a stale prediction
    # is not left on screen after pressing Clear.
    clear_btn.click(
        fn=lambda: [None, 1994, gr.Slider(visible=True), None, None, None],
        outputs=[opinion, year, year, drop, op_level, output_textbox],
    )


# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(
        # Optional login gate, currently disabled:
        # auth=("sc2024", "sc2024"),
        # auth_message="To request access, please email [email protected]",
    )