import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


class TwitterEmotionClassifier:
    def __init__(self, model_name: str, model_type: str):
        self.is_gpu = False
        self.model_type = model_type
        device = torch.device("cuda") if self.is_gpu else torch.device("cpu")
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model.to(device)
        model.eval()
        # pipeline() expects a device index: -1 for CPU, 0 for the first GPU.
        # Since is_gpu is a bool, `is_gpu - 1` maps False -> -1 and True -> 0.
        self.bertweet = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            device=self.is_gpu - 1,
        )
        # The DeBERTa pipeline is loaded lazily on first request.
        self.deberta = None
        # Label ids emitted by the models, in the order used by predict().
        self.emotions = {
            "LABEL_0": "sadness",
            "LABEL_1": "joy",
            "LABEL_2": "love",
            "LABEL_3": "anger",
            "LABEL_4": "fear",
            "LABEL_5": "surprise",
        }

    def get_model(self, model_type: str):
        """Return the requested pipeline, loading the DeBERTa model on demand."""
        if self.model_type == "bertweet" and model_type == self.model_type:
            return self.bertweet
        elif model_type == "deberta":
            if self.deberta:
                return self.deberta
            model = AutoModelForSequenceClassification.from_pretrained(
                "Emanuel/twitter-emotion-deberta-v3-base"
            )
            tokenizer = AutoTokenizer.from_pretrained(
                "Emanuel/twitter-emotion-deberta-v3-base"
            )
            self.deberta = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=self.is_gpu - 1,
            )
            return self.deberta

    def predict(self, twitter: str, model_type: str):
        """Classify a tweet and return a label -> score mapping for the Label output."""
        classifier = self.get_model(model_type)
        # return_all_scores=True yields a score for every label, not just the top one.
        preds = classifier(twitter, return_all_scores=True)
        if preds:
            pred = preds[0]
            res = {
                "Sadness 😢": pred[0]["score"],
                "Joy 😂": pred[1]["score"],
                "Love 💛": pred[2]["score"],
                "Anger 😠": pred[3]["score"],
                "Fear 😱": pred[4]["score"],
                "Surprise 😮": pred[5]["score"],
            }
            return res
        return None


def main():
    model = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
    # Gradio 2.x-style interface (gr.inputs / gr.outputs namespaces).
    interFace = gr.Interface(
        fn=model.predict,
        inputs=[
            gr.inputs.Textbox(
                placeholder="What's happening?", label="Tweet content", lines=5
            ),
            gr.inputs.Radio(["bertweet", "deberta"], label="Model"),
        ],
        outputs=gr.outputs.Label(num_top_classes=6, label="Emotions of this tweet"),
        verbose=True,
        examples=[
            ["This GOT show just remember LOTR times!", "bertweet"],
            ["Man, that my 30 days of training just got a NaN loss!!!", "bertweet"],
            ["I couldn't see 3 Tom Hollands coming...", "bertweet"],
            [
                "There is nothing better than a soul-warming coffee in the morning",
                "bertweet",
            ],
            ["I fear the vanishing gradient a lot", "deberta"],
        ],
        title="Emotion classification with DeBERTa-v3 🤗",
        description="",
        theme="huggingface",
    )
    interFace.launch()


if __name__ == "__main__":
    main()
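
# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original Space):
# the classifier can also be queried directly, without launching the Gradio UI,
# e.g. for a quick smoke test. Kept commented out so the module still only
# starts the interface when run as a script.
#
# clf = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
# print(clf.predict("I fear the vanishing gradient a lot", "bertweet"))
# ---------------------------------------------------------------------------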