Classifier / app.py
Taranosaurus's picture
Added faster model switching and truncation to prevent errors on long inputs
a5709b4
raw
history blame
2.74 kB
from transformers import pipeline, AutoTokenizer
import gradio as gr
import torch
# Run the models on the GPU when one is visible to PyTorch, otherwise on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Model checkpoints for the two stages of the app.
summary_checkpoint = "facebook/bart-large-cnn"  # alternative: "google/pegasus-large"
oracle_checkpoint = "facebook/bart-large-mnli"

# The tokenizer is loaded explicitly so it can be handed to the summarisation
# pipeline. No `device` argument is passed here: it is not a tokenizer loading
# option, and device placement is already requested on the pipelines below.
tokenizer = AutoTokenizer.from_pretrained(summary_checkpoint)

# Summarisation pipeline used to condense the commit notes.
summary = pipeline(
    task="summarization",
    model=summary_checkpoint,
    tokenizer=tokenizer,
    device=device,
)

# Zero-shot classifier used to score text against the user-chosen labels.
oracle = pipeline(
    task="zero-shot-classification",
    model=oracle_checkpoint,
    device=device,
)
# Every label offered in the "Classification Labels" dropdown.
labels = [
    "merge",
    "revert",
    "fix",
    "feature",
    "update",
    "refactor",
    "test",
    "security",
    "documentation",
    "style",
]

# Labels pre-selected in the dropdown when the page loads.
selected_labels = [
    "feature",
    "update",
    "refactor",
    "test",
    "security",
    "documentation",
    "style",
]
def do_the_thing(input, labels):
    """Summarise the given notes, then classify both the notes and the summary.

    Args:
        input: The text to summarise and classify. The name shadows the
            builtin ``input`` but is kept so the signature stays unchanged.
        labels: Candidate labels passed to the zero-shot classifier.

    Returns:
        A three-item list ``[summary_text, input_scores, summary_scores]``.
        ``summary_text`` is the generated summary; each ``*_scores`` value is
        a dict mapping a label to the score the classifier returned for it,
        first for the original text and then for the summary.

    Raises:
        gr.Error: If ``labels`` is empty, so the user sees a clear message
            instead of a failure from inside the classifier.
    """
    # Fail fast: without labels there is nothing to classify against, and
    # checking first avoids running the summariser for a doomed request.
    if not labels:
        raise gr.Error("Select at least one classification label.")

    summarisation = summary(input, truncation=True)[0]['summary_text']

    # Classify the original text and its summary in a single batched call;
    # results come back in the same order as `sequences`.
    zsc_results = oracle(
        sequences=[input, summarisation],
        candidate_labels=labels,
        multi_label=False,
        batch_size=2,
    )

    # Each result carries parallel 'labels' and 'scores' lists; pair them up.
    classifications_input = dict(
        zip(zsc_results[0]['labels'], zsc_results[0]['scores'])
    )
    classifications_summary = dict(
        zip(zsc_results[1]['labels'], zsc_results[1]['scores'])
    )
    return [summarisation, classifications_input, classifications_summary]
# Build the web UI. Components are created in the order they should appear.
# NOTE(review): the Row/Column nesting below is reconstructed; the reviewed
# copy of this file had lost its indentation. Confirm the layout is as intended.
with gr.Blocks() as frontend:
    # Page heading and usage instructions.
    gr.Markdown(f"## Git Commit Classifier\n\nThis tool is to take the notes from a commit, summarise and classify the original and the summary.\n\nTo get the git commit notes, clone the repo and the run `git log --all --pretty='format:Subject: %s%nBody: %b%n-----%n'`")

    # Free-text input and the button that triggers processing.
    input_value = gr.TextArea(label="Notes to Summarise")
    btn_submit = gr.Button(value="Summarise and Classify")

    # First row: label picker on the left, generated summary on the right.
    with gr.Row():
        with gr.Column():
            input_labels = gr.Dropdown(
                label="Classification Labels",
                choices=labels,
                multiselect=True,
                value=selected_labels,
                interactive=True,
                allow_custom_value=True,
                info="Labels to classify the original text and summary",
            )
        with gr.Column():
            output_summary_text = gr.TextArea(label="Summary of Notes")

    # Second row: classification of the original text next to that of the summary.
    with gr.Row():
        with gr.Column():
            output_original_labels = gr.Label(label="Original Text Classification")
        with gr.Column():
            output_summary_labels = gr.Label(label="Summary Text Classification")

    # Wire the button to the handler; outputs are listed in the same order
    # as the items of the list that do_the_thing returns.
    btn_submit.click(
        fn=do_the_thing,
        inputs=[input_value, input_labels],
        outputs=[output_summary_text, output_original_labels, output_summary_labels],
    )

# Start the Gradio server for the assembled interface.
frontend.launch()