File size: 1,319 Bytes
f80cc50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import gradio as gr
from transformers import pipeline, TextClassificationPipeline
# Shared text-classification pipeline, built once at import time.
# top_k=13 asks for up to 13 scored labels per input (presumably one per
# class of this model; confirm against the model card).
pipe = pipeline(
    model="raminass/scotus-v10",
    top_k=13,
    padding=True,
    truncation=True,
)

def average_text(text, model):
    """Classify every chunk of text and average the scores per label.

    Args:
        text: The chunk(s) to classify, passed to ``model`` unchanged.
        model: Callable that returns, for each chunk, a list of dicts
            holding at least a ``'label'`` and a ``'score'`` key.

    Returns:
        A 2-tuple ``(summary, per_chunk)``. ``summary`` maps each label to
        its mean score over all chunks, rounded to 2 decimals and ordered
        from highest to lowest mean. ``per_chunk`` is the model output with
        every ``'score'`` rounded to 2 decimals.
    """
    result = model(text)
    # A lone input may come back as one flat list of score dicts rather than
    # a list of such lists; wrap it so both shapes are handled uniformly.
    if result and isinstance(result[0], dict):
        result = [result]

    # Collect the raw scores: rounding each one before averaging would let
    # the individual rounding errors add up and shift the reported mean.
    scores_by_label = {}
    for chunk_scores in result:
        for entry in chunk_scores:
            scores_by_label.setdefault(entry['label'], []).append(entry['score'])

    summary = {
        label: round(sum(scores) / len(scores), 2)
        for label, scores in scores_by_label.items()
    }
    ranked = dict(sorted(summary.items(), key=lambda item: item[1], reverse=True))

    # Scores are rounded for display only; every other key is passed through.
    per_chunk = [
        [
            {key: round(value, 2) if key == 'score' else value
             for key, value in entry.items()}
            for entry in chunk_scores
        ]
        for chunk_scores in result
    ]
    return ranked, per_chunk

def greet(opinion):
    """Return the averaged label scores for one opinion text.

    The opinion is passed through ``remove_citations`` and ``chunk_data``;
    the ``'text'`` column of the result is converted to a list and scored
    with the module-level ``pipe`` via ``average_text``. Only the first
    element of ``average_text``'s result is returned.

    NOTE(review): ``remove_citations`` and ``chunk_data`` are neither defined
    nor imported in the visible part of this file; confirm where they come
    from, otherwise this raises ``NameError`` when called.
    """
    cleaned = remove_citations(opinion)
    # presumably chunk_data returns a pandas DataFrame; verify
    chunks = chunk_data(cleaned)['text'].to_list()
    averaged, _per_chunk = average_text(chunks, pipe)
    return averaged

# Build the UI: an input box for the opinion text, an output box for the
# prediction, and a button that triggers the prediction.
with gr.Blocks() as demo:
    opinion = gr.Textbox(label="Opinion")
    # NOTE(review): greet returns a dict, which is shown here in a plain
    # Textbox; confirm the rendered form is what is wanted.
    output = gr.Textbox(label="Result")
    greet_btn = gr.Button("Predict")
    # Clicking the button runs greet on the opinion text and writes the
    # result to the output box; api_name names the endpoint "SCOTUS".
    greet_btn.click(fn=greet, inputs=opinion, outputs=output, api_name="SCOTUS")

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()