|
import gradio as gr |
|
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
|
import torch |
|
|
|
# Static text shown by the Gradio UI (passed to gr.Interface below).
title = "SentBERT"
description = "Sentiment Analysis :) && :("

# Clickable sample inputs; each inner list holds the value for the single
# text input of the interface.
examples = [
    ["That ice cream was really bad"],
    ["Great to meet you!"],
    ["Hey, there's a snake there"],
]
|
|
|
# Maps the model's output class index to the human-readable label that is
# displayed in the UI.
class2interpret = {0: 'Positive/Neutral', 1: 'Negative'}
|
|
|
# Checkpoint used for both the tokenizer and the classifier.
# NOTE(review): "distilbert-base-uncased" is the generic pretrained checkpoint,
# not one fine-tuned for sentiment; its sequence-classification head appears to
# be freshly initialised on load, so the scores are not meaningful until a
# fine-tuned checkpoint is substituted here -- confirm and replace.
_CHECKPOINT = "distilbert-base-uncased"

# Holds the (tokenizer, model) pair once loaded so requests can reuse it.
_LOADED = {}


def _get_tokenizer_and_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use."""
    if "pair" not in _LOADED:
        tokenizer = DistilBertTokenizer.from_pretrained(_CHECKPOINT)
        model = DistilBertForSequenceClassification.from_pretrained(_CHECKPOINT)
        # Publish both objects in a single assignment so a concurrent caller
        # never observes a half-populated cache.
        _LOADED["pair"] = (tokenizer, model)
    return _LOADED["pair"]


def classify(example):
    """Score a piece of text against the labels in ``class2interpret``.

    Args:
        example: The text to classify.

    Returns:
        A tuple of two equal dicts mapping each label to its softmax
        probability. Two values are returned because the Gradio interface
        declares two outputs (``'label'`` and ``'json'``).
    """
    # Loading the weights is expensive; do it once instead of per request.
    tokenizer, model = _get_tokenizer_and_model()

    # Truncate so that overly long input cannot exceed the model's maximum
    # sequence length and fail inside the forward pass.
    inputs = tokenizer(example, return_tensors="pt", truncation=True)

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # logits is indexed as (batch, class); the batch holds the single input.
    probs = torch.softmax(logits, dim=1).tolist()[0]

    scores = {label: probs[index] for index, label in class2interpret.items()}
    # Hand each output component its own dict object, as the original did.
    return scores, dict(scores)
|
|
|
# Wire the classifier into a Gradio UI: one text box in, and the same scores
# rendered twice (as a label widget and as raw JSON).
interface = gr.Interface(
    fn=classify,
    inputs="text",
    outputs=["label", "json"],
    examples=examples,
    description=description,
    title=title,
)

# Start the web server for the demo.
interface.launch()