File size: 1,179 Bytes
84fb171
 
71c9e6d
 
84fb171
 
 
71c9e6d
84fb171
 
 
 
 
 
 
 
 
 
 
71c9e6d
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import gradio as grad
import torch
# from transformers import pipeline
from transformers import BartForSequenceClassification, BartTokenizer

# Tokenizer and classification model come from the same checkpoint and are
# created once at import time; classify() reads them as module globals on
# every request instead of reloading them.
_CHECKPOINT = 'facebook/bart-large-mnli'
bart_tkn, mdl = (
    BartTokenizer.from_pretrained(_CHECKPOINT),
    BartForSequenceClassification.from_pretrained(_CHECKPOINT),
)

def classify(text, label):
    """Return how strongly *text* entails *label*, as a percentage.

    *text* is fed to the model as the premise and *label* as the hypothesis.
    Of the model's three logits, only columns 0 and 2 are kept and softmaxed
    against each other; the share of column 2 is returned, scaled to 0-100.
    Column 1 does not take part in the score.
    (NOTE(review): columns 0/2 are assumed to be contradiction/entailment in
    this checkpoint's label map -- confirm via ``mdl.config.id2label``.)

    Args:
        text: Text to classify (the premise). Over-long text is truncated
            to the tokenizer's maximum length rather than failing.
        label: Candidate label, used verbatim as the hypothesis. The whole
            string is scored as one hypothesis; it is NOT split on commas.

    Returns:
        Entailment probability in percent, a float in [0, 100].
    """
    tkn_ids = bart_tkn.encode(
        text,
        label,
        return_tensors = "pt",
        # Trim only the premise when the pair is too long, so the hypothesis
        # (the label being tested) is always kept intact.
        truncation = "only_first",
    )

    # Pure inference: don't build an autograd graph we'd never backprop through.
    with torch.no_grad():
        tkn_lgts = mdl(tkn_ids)[0]

    # Keep logit columns 0 and 2, renormalise them, and report column 2's share.
    entail_contra_tkn_lgts = tkn_lgts[:, [0, 2]]
    probab = entail_contra_tkn_lgts.softmax(dim = 1)
    response = probab[:, 1].item() * 100

    return response

# Gradio UI: a text box and a label box feed classify(text, label); its numeric
# score is shown in the output box. launch() runs at import time, as before.
txt = grad.Textbox(lines = 1, label = "English", placeholder = "text to be classified")
# classify() scores the entire string as ONE hypothesis and never splits it,
# so ask for a single label; a comma-separated list would be scored as a
# single phrase and give a misleading number.
labels = grad.Textbox(lines = 1, label = "Label", placeholder = "a single candidate label, e.g. sports")
out = grad.Textbox(lines = 1, label = "Entailment probability (%)")

grad.Interface(
    classify,
    inputs = [txt, labels],
    outputs = out
).launch()