|
import gradio as gr |
|
from transformers import pipeline, TextClassificationPipeline |
|
# Module-level text-classification pipeline for the "raminass/scotus-v10"
# model, built once at import time and shared by every request.
# top_k=13 asks for the 13 highest-scoring labels per input; padding and
# truncation are presumably forwarded to the tokenizer so over-long chunks are
# cut to the model's maximum length (confirm for the transformers version used).
pipe = pipeline(
    model="raminass/scotus-v10",
    top_k=13,
    padding=True,
    truncation=True,
)
|
|
|
def average_text(text, model):
    """Classify *text* with *model* and average each label's score across chunks.

    Args:
        text: The input passed straight to ``model``. The caller in this file
            passes a list of text chunks.
        model: A callable (e.g. a transformers text-classification pipeline)
            returning, for each input, a list of ``{'label': ..., 'score': ...}``
            dicts. A flat list of such dicts (presumably what a single string
            input yields) is also accepted and treated as a single chunk.

    Returns:
        A 2-tuple ``(summary, per_chunk)``:

        * ``summary`` maps each label to the mean of its raw scores over the
          chunks it appeared in, rounded to 2 decimals, ordered from highest
          to lowest mean (ties keep first-seen order).
        * ``per_chunk`` is the model output with every ``'score'`` value
          rounded to 2 decimals; other keys are copied unchanged.
    """
    result = model(text)

    # Normalize a flat list of label dicts into a one-chunk list of lists so
    # the loops below can always iterate chunk -> entry.
    if result and isinstance(result[0], dict):
        result = [result]

    scores = {}
    for chunk in result:
        for entry in chunk:
            # Keep raw scores here: rounding before averaging compounds
            # rounding error and can shift the mean by 0.01.
            scores.setdefault(entry['label'], []).append(entry['score'])

    summary = {
        label: round(sum(values) / len(values), 2)
        for label, values in scores.items()
    }

    per_chunk = [
        [
            {k: round(v, 2) if k == 'score' else v for k, v in entry.items()}
            for entry in chunk
        ]
        for chunk in result
    ]

    ranked = dict(sorted(summary.items(), key=lambda item: item[1], reverse=True))
    return ranked, per_chunk
|
|
|
def greet(opinion):
    """Return the averaged per-label scores for one opinion text.

    Handler for the "Predict" button. The opinion is cleaned with
    ``remove_citations``, split with ``chunk_data``, and the ``'text'`` values
    of the resulting chunks are scored by ``average_text`` using the
    module-level ``pipe``. Only the summary dict (label -> mean score,
    highest first) is returned; the per-chunk scores are dropped.
    """
    # NOTE(review): chunk_data and remove_citations are neither defined nor
    # imported anywhere in this file as shown -- confirm they are provided,
    # otherwise this raises NameError on the first request.
    cleaned = remove_citations(opinion)
    # Presumably chunk_data returns a pandas DataFrame with a 'text' column.
    chunk_texts = chunk_data(cleaned)['text'].to_list()
    summary = average_text(chunk_texts, pipe)[0]
    return summary
|
|
|
# Minimal UI: an input box for the opinion, an output box for the result, and
# a Predict button. Gradio lays components out in the order they are created
# inside the Blocks context, so keep these statements in this order.
with gr.Blocks() as demo:
    opinion = gr.Textbox(label="Opinion")
    # NOTE(review): greet returns a label -> score dict, which this Textbox
    # shows as text; gr.Label may fit better, but switching would change the
    # output format seen by "SCOTUS" API clients.
    output = gr.Textbox(label="Result")
    greet_btn = gr.Button("Predict")
    # The handler is also published as the named API endpoint "SCOTUS".
    greet_btn.click(
        greet,
        inputs=opinion,
        outputs=output,
        api_name="SCOTUS",
    )


if __name__ == "__main__":
    demo.launch()
|
|