# Provenance (captured from the hosting page): author Emanuel,
# commit 6160ca6 ("First commit"), file size 3.43 kB.
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
class TwitterEmotionClassifier:
    """Tweet emotion classifier backed by transformers text-classification pipelines.

    The primary pipeline (``model_name``) is built eagerly in ``__init__`` and
    stored on ``self.bertweet``. A DeBERTa pipeline is built lazily the first
    time it is requested and then cached on ``self.deberta``.
    """

    # Hub id of the lazily loaded DeBERTa model.
    DEBERTA_MODEL_NAME = "Emanuel/twitter-emotion-deberta-v3-base"

    # Display names returned by ``predict``, in the positional order in which
    # the classifier's scores are read (LABEL_0 ... LABEL_5, see ``emotions``).
    DISPLAY_LABELS = (
        "Sadness 😢",
        "Joy 😂",
        "Love 💛",
        "Anger 😠",
        "Fear 😱",
        "Surprise 😮",
    )

    def __init__(self, model_name: str, model_type: str, is_gpu: bool = False):
        """Load the primary model and tokenizer and wrap them in a pipeline.

        Args:
            model_name: Hub id or local path of the primary model.
            model_type: Name under which callers select the primary model
                (e.g. ``"bertweet"``).
            is_gpu: Run on the first CUDA device instead of the CPU.
        """
        self.is_gpu = is_gpu
        self.model_type = model_type
        device = torch.device("cuda") if self.is_gpu else torch.device("cpu")
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model.to(device)
        model.eval()
        # Attribute name kept for compatibility; it holds whichever primary
        # model was passed in, not necessarily BERTweet.
        self.bertweet = self._build_pipeline(model, tokenizer)
        self.deberta = None
        self.emotions = {
            "LABEL_0": "sadness",
            "LABEL_1": "joy",
            "LABEL_2": "love",
            "LABEL_3": "anger",
            "LABEL_4": "fear",
            "LABEL_5": "surprise",
        }

    def _pipeline_device(self) -> int:
        """Return the pipeline device index: 0 (first GPU) if ``is_gpu`` else -1 (CPU)."""
        return 0 if self.is_gpu else -1

    def _build_pipeline(self, model, tokenizer):
        """Wrap ``model`` and ``tokenizer`` in a text-classification pipeline."""
        return pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            device=self._pipeline_device(),
        )

    def get_model(self, model_type: str):
        """Return the pipeline to use for ``model_type``.

        ``None``/empty (no model chosen) and the primary ``self.model_type``
        both select the primary pipeline. ``"deberta"`` loads the DeBERTa
        pipeline on first use and reuses it afterwards.

        Raises:
            ValueError: if ``model_type`` is not recognised.
        """
        if not model_type or model_type == self.model_type:
            return self.bertweet
        if model_type == "deberta":
            if self.deberta is None:
                model = AutoModelForSequenceClassification.from_pretrained(
                    self.DEBERTA_MODEL_NAME
                )
                tokenizer = AutoTokenizer.from_pretrained(self.DEBERTA_MODEL_NAME)
                self.deberta = self._build_pipeline(model, tokenizer)
            return self.deberta
        # Previously this fell through to ``None`` and predict() crashed with
        # "'NoneType' object is not callable"; fail with a clear message.
        raise ValueError(f"Unknown model type: {model_type!r}")

    def predict(self, twitter: str, model_type: str):
        """Score ``twitter`` against the six emotion classes.

        Args:
            twitter: Tweet text to classify.
            model_type: Which model to use (see ``get_model``).

        Returns:
            Dict mapping each label in ``DISPLAY_LABELS`` to its score, or
            ``None`` if the pipeline returned nothing.
        """
        classifier = self.get_model(model_type)
        # The code expects a one-element list whose item is the list of
        # per-class ``{"label", "score"}`` dicts for this single input.
        preds = classifier(twitter, return_all_scores=True)
        if not preds:
            return None
        # Scores are read positionally; this assumes the pipeline lists them
        # in label-id order (LABEL_0 ... LABEL_5).
        return {
            label: entry["score"]
            for label, entry in zip(self.DISPLAY_LABELS, preds[0])
        }
def main():
    """Build the BERTweet emotion classifier and launch the Gradio demo UI."""
    model = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
    interface = gr.Interface(
        fn=model.predict,
        inputs=[
            gr.inputs.Textbox(
                placeholder="What's happening?", label="Tweet content", lines=5
            ),
            # Pre-select a model so predict() is never called with None.
            gr.inputs.Radio(
                ["bertweet", "deberta"], default="bertweet", label="Model"
            ),
        ],
        outputs=gr.outputs.Label(num_top_classes=6, label="Emotions of this tweet"),
        verbose=True,
        examples=[
            ["This GOT show just remember LOTR times!", "bertweet"],
            ["Man, that my 30 days of training just got a NaN loss!!!", "bertweet"],
            ["I couldn't see 3 Tom Hollands coming...", "bertweet"],
            [
                "There is nothing better than a soul-warming coffee in the morning",
                "bertweet",
            ],
            ["I fear the vanishing gradient a lot", "deberta"],
        ],
        title="Emotion classification with DeBERTa-v3 🤖",
        description="",
        theme="huggingface",
    )
    interface.launch()
if __name__ == "__main__":
    # Only launch the demo when executed as a script, not when imported.
    main()