Spaces:
Runtime error
Runtime error
File size: 3,429 Bytes
6160ca6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
class TwitterEmotionClassifier:
    """Classify the emotion expressed in a tweet with Hugging Face pipelines.

    The primary pipeline is built eagerly from ``model_name`` and served for
    ``model_type == "bertweet"``. The DeBERTa-v3 pipeline is built lazily the
    first time "deberta" is requested and is cached afterwards.
    """

    # Hub checkpoint loaded on demand when the "deberta" model is selected.
    DEBERTA_MODEL_NAME = "Emanuel/twitter-emotion-deberta-v3-base"

    # Display names in label-id order (LABEL_0 .. LABEL_5, see ``self.emotions``).
    # ``predict`` maps pipeline scores onto these positionally, because
    # ``return_all_scores=True`` yields one score per label in label-id order.
    DISPLAY_LABELS = (
        "Sadness 😢",
        "Joy 😂",
        "Love 😍",
        "Anger 😠",
        "Fear 😱",
        "Surprise 😮",
    )

    def __init__(self, model_name: str, model_type: str, use_gpu: bool = False):
        """Load the primary model and prepare the lazy DeBERTa slot.

        Args:
            model_name: Hub id or local path of the primary checkpoint.
            model_type: Which selector value the primary pipeline answers to
                (the app passes "bertweet").
            use_gpu: Run pipelines on the first CUDA device instead of the CPU.
                Defaults to False, matching the previous hard-coded behavior.
        """
        self.is_gpu = use_gpu
        self.model_type = model_type
        self.bertweet = self._build_pipeline(model_name)
        self.deberta = None  # built on first request in get_model()
        # Raw pipeline label -> emotion name (kept for callers that use it).
        self.emotions = {
            "LABEL_0": "sadness",
            "LABEL_1": "joy",
            "LABEL_2": "love",
            "LABEL_3": "anger",
            "LABEL_4": "fear",
            "LABEL_5": "surprise",
        }

    def _build_pipeline(self, model_name: str):
        """Load ``model_name`` and wrap it in a text-classification pipeline."""
        device = torch.device("cuda") if self.is_gpu else torch.device("cpu")
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model.to(device)
        model.eval()
        return pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # pipeline() takes a device index: -1 is CPU, 0 is the first GPU.
            device=0 if self.is_gpu else -1,
        )

    def get_model(self, model_type: str):
        """Return the pipeline for ``model_type``.

        An empty selection (``None`` or ``""``, e.g. an unselected radio button)
        falls back to the model type this instance was constructed with.

        Raises:
            ValueError: if ``model_type`` names no supported pipeline.
        """
        if not model_type:
            model_type = self.model_type
        if model_type == "bertweet" and model_type == self.model_type:
            return self.bertweet
        if model_type == "deberta":
            if self.deberta is None:
                self.deberta = self._build_pipeline(self.DEBERTA_MODEL_NAME)
            return self.deberta
        raise ValueError(f"Unsupported model type: {model_type!r}")

    @classmethod
    def _format_scores(cls, scores):
        """Map one input's positional label scores onto ``DISPLAY_LABELS``.

        Raises:
            ValueError: if the number of scores differs from the label count.
        """
        if len(scores) != len(cls.DISPLAY_LABELS):
            raise ValueError(
                f"Expected {len(cls.DISPLAY_LABELS)} label scores, got {len(scores)}"
            )
        return {label: entry["score"] for label, entry in zip(cls.DISPLAY_LABELS, scores)}

    def predict(self, twitter: str, model_type: str):
        """Score the tweet text ``twitter`` against every emotion class.

        Returns:
            A dict mapping each display label to its score, or ``None`` when
            the pipeline returns no predictions.
        """
        classifier = self.get_model(model_type)
        # return_all_scores=True: for a single string the pipeline returns a
        # one-element list whose item holds every label's score in id order.
        preds = classifier(twitter, return_all_scores=True)
        if not preds:
            return None
        return self._format_scores(preds[0])
def main():
    """Build the Gradio demo around ``TwitterEmotionClassifier`` and launch it.

    Uses the top-level Gradio components (``gr.Textbox``/``gr.Radio``/
    ``gr.Label``); the legacy ``gr.inputs``/``gr.outputs`` namespaces no
    longer exist in current Gradio releases.
    """
    model = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
    interface = gr.Interface(
        fn=model.predict,
        inputs=[
            gr.Textbox(
                placeholder="What's happening?", label="Tweet content", lines=5
            ),
            # Preselect the eagerly loaded model so the radio never submits None.
            gr.Radio(["bertweet", "deberta"], value="bertweet", label="Model"),
        ],
        outputs=gr.Label(num_top_classes=6, label="Emotions of this tweet is "),
        examples=[
            ["This GOT show just remember LOTR times!", "bertweet"],
            ["Man, that my 30 days of training just got a NaN loss!!!", "bertweet"],
            ["I couldn't see 3 Tom Hollands coming...", "bertweet"],
            [
                "There is nothing better than a soul-warming coffee in the morning",
                "bertweet",
            ],
            ["I fear the vanishing gradient a lot", "deberta"],
        ],
        title="Emotion classification with DeBERTa-v3 🤗",
        description="",
    )
    interface.launch()
# Launch the demo only when run as a script, so importing this module
# does not load models or start a server.
if __name__ == "__main__":
    main()
|