File size: 1,783 Bytes
db9e12e
fb39a61
 
 
 
 
 
 
db9e12e
44be0d1
 
fb39a61
 
 
 
 
44be0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Fetch the checkpoint's tokenizer and sequence-classification model a single
# time, at import, so every request served by the UI reuses the same objects.
model_name = "AnkitAI/deberta-xlarge-base-emotions-classifier"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

def classify_emotion(text, *, clf_model=None, clf_tokenizer=None):
    """Score *text* against every emotion class of the classifier.

    Args:
        text: The string to classify.
        clf_model: Optional sequence-classification model to use instead of
            the module-level ``model`` (e.g. another checkpoint or a test
            double). Calling it must return an object with a ``logits``
            attribute.
        clf_tokenizer: Optional tokenizer to use instead of the module-level
            ``tokenizer``.

    Returns:
        A dict mapping each label name to its softmax probability, in the
        model's output order. Names come from the model config's
        ``id2label`` when it carries real names (see ``_emotion_label_names``).
    """
    clf_model = model if clf_model is None else clf_model
    clf_tokenizer = tokenizer if clf_tokenizer is None else clf_tokenizer

    inputs = clf_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Pure inference: don't record an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = clf_model(**inputs)
    # One input string -> a batch of one, so keep only row 0.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    # Size the label list from the logits themselves, never from a fixed list,
    # so a class-count mismatch can't raise IndexError or drop classes.
    labels = _emotion_label_names(getattr(clf_model, "config", None), probs.shape[-1])
    return dict(zip(labels, probs.tolist()))


# Label order the app originally hard-coded (its own comment said it might need
# adjusting). Used only as a fallback when the model config has no real names
# (just generic "LABEL_<i>" placeholders) and the class count matches.
_DEFAULT_EMOTION_LABELS = ("joy", "anger", "sadness", "fear", "surprise", "love")


def _emotion_label_names(config, num_classes):
    """Return one display name per output class, index-aligned with the logits."""
    id2label = getattr(config, "id2label", None) or {}
    names = [str(id2label.get(i, f"LABEL_{i}")) for i in range(num_classes)]
    placeholders = [f"LABEL_{i}" for i in range(num_classes)]
    if names == placeholders and num_classes == len(_DEFAULT_EMOTION_LABELS):
        return list(_DEFAULT_EMOTION_LABELS)
    return names

def validate_input(text):
    """Classify *text*, or return a prompt message when there is nothing to classify.

    Args:
        text: Raw text from the input box; ``None`` is treated as empty.

    Returns:
        The ``classify_emotion`` result (label -> probability dict) for
        non-blank input; otherwise the string ``"Please enter some text."``,
        which is shown in the output Label instead of a prediction.
    """
    # Skip the model for missing or whitespace-only input. Checking None first
    # avoids an AttributeError on `.strip()`.
    if text is None or not text.strip():
        return "Please enter some text."
    return classify_emotion(text)

# Build the UI: a multi-line text box feeds validate_input, and the returned
# per-emotion scores (or prompt message) are rendered by a Label component.
_input_box = gr.Textbox(label="Input Text", placeholder="Enter text here...", lines=5)
_output_label = gr.Label(label="Predicted Emotion")
_example_texts = [
    "I am feeling great today!",
    "I am so sad and depressed.",
    "I am excited about the new project.",
]

interface = gr.Interface(
    fn=validate_input,
    inputs=_input_box,
    outputs=_output_label,
    examples=_example_texts,
    title="Emotion Classifier",
    description="Enter some text and let the model predict the emotion.",
)

# Custom stylesheet for the app: light-grey page background with Arial text,
# blue <h1> headings, and rounded corners plus a drop shadow on the
# `.gradio-container` element. It is handed to interface.launch() at the end
# of the file.
css = """
body {
    background-color: #f8f9fa;
    font-family: Arial, sans-serif;
}
h1 {
    color: #007bff;
}
.gradio-container {
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
}
"""

# Start the web server with the custom stylesheet applied.
# NOTE(review): `css` is only a launch() parameter in newer Gradio releases
# (Gradio 6 moved it there); older releases take it in the Interface/Blocks
# constructor instead. Confirm against the pinned gradio version.
interface.launch(
    css=css,
    inline=False,            # don't try to render the app inline in a notebook
    server_port=8080,
    server_name="0.0.0.0",   # bind all network interfaces, e.g. inside a container
)