# Depreesion/llm/mentalBERT.py
# Author: vitorcalvi
# Commit: fc286f6 ("pre-launch")
import logging

import gradio as gr
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
# Hugging Face Hub checkpoint ids used by this app.
MENTAL_ROBERTA_CHECKPOINT = "mental/mental-roberta-base"
EMOTION_CHECKPOINT = "j-hartmann/emotion-english-distilroberta-base"

# A single tokenizer, taken from the mental-roberta checkpoint, is used to
# encode input for both classifiers below.
tokenizer = RobertaTokenizer.from_pretrained(MENTAL_ROBERTA_CHECKPOINT)

# NOTE(review): mental/mental-roberta-base looks like a domain-adapted base
# (masked-LM) checkpoint. If it ships no sequence-classification head,
# from_pretrained initialises one randomly and the sentiment predictions are
# meaningless. Confirm, and swap in a fine-tuned sentiment checkpoint if so.
sentiment_model = RobertaForSequenceClassification.from_pretrained(MENTAL_ROBERTA_CHECKPOINT)
emotion_model = RobertaForSequenceClassification.from_pretrained(EMOTION_CHECKPOINT)

# Class names, indexed by each model's output position.
# Presumably these match the checkpoints' id2label order; verify against the
# model configs.
sentiment_labels = ["negative", "positive"]
emotion_labels = [
    "anger",
    "disgust",
    "fear",
    "joy",
    "neutral",
    "sadness",
    "surprise",
]
logger = logging.getLogger(__name__)


def _top_prediction(name, model, labels, inputs):
    """Run one sequence-classification model and return its most likely label.

    Args:
        name: Short tag used only in debug log messages (e.g. "Sentiment").
        model: Classification model whose output exposes ``logits``.
        labels: Label names indexed by the model's class ids.
        inputs: Tokenizer output (PyTorch tensors) for a single text.

    Returns:
        ``(label, probability)`` for the highest-probability class, where
        ``probability`` is a Python float taken from the softmax.
    """
    logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    logger.debug("%s logits %s: %s", name, tuple(logits.shape), logits)
    logger.debug("%s probs %s: %s", name, tuple(probs.shape), probs)
    # Only one text is tokenized per call, so row 0 holds the prediction.
    top_prob, top_index = torch.max(probs[0], dim=-1)
    return labels[top_index.item()], top_prob.item()


def analyze_text(text):
    """Classify *text* for overall sentiment and for a specific emotion.

    Args:
        text: The string entered in the Gradio textbox.

    Returns:
        A 4-tuple of strings ``(sentiment, sentiment_score, emotion,
        emotion_score)``. Each score is the winning softmax probability,
        formatted to four decimal places. If anything fails, the exception
        is logged with its traceback and ``("Error", "N/A", "Error", "N/A")``
        is returned so the UI still renders.
    """
    try:
        # Both models are loaded as RobertaForSequenceClassification and are
        # fed from this one tokenizer.
        # NOTE(review): this assumes the emotion checkpoint uses the same BPE
        # vocabulary as the mental-roberta tokenizer. Confirm.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        # Inference only: skip building the autograd graph on every request.
        with torch.no_grad():
            sentiment, sentiment_prob = _top_prediction(
                "Sentiment", sentiment_model, sentiment_labels, inputs
            )
            emotion, emotion_prob = _top_prediction(
                "Emotion", emotion_model, emotion_labels, inputs
            )
    except Exception:
        # Top-level UI boundary: keep the app responsive, but keep the traceback.
        logger.exception("Text analysis failed")
        return "Error", "N/A", "Error", "N/A"
    return sentiment, f"{sentiment_prob:.4f}", emotion, f"{emotion_prob:.4f}"
# Input: a five-line textbox, pre-filled with an example sentence.
# It is created before the output fields, matching the original construction order.
text_input = gr.Textbox(
    lines=5,
    placeholder="Enter text here...",
    value="I don’t know a lot but what I do know is, we don’t start off very big and we all try to make each other smaller.",
)

# Outputs, in the same order as the tuple analyze_text returns:
# sentiment, sentiment score, emotion, emotion score.
result_fields = [
    gr.Textbox(label="Detected Sentiment"),
    gr.Textbox(label="Sentiment Confidence Score"),
    gr.Textbox(label="Detected Emotion"),
    gr.Textbox(label="Emotion Confidence Score"),
]

interface = gr.Interface(
    fn=analyze_text,
    inputs=text_input,
    outputs=result_fields,
    title="Sentiment and Emotion Analysis: Detecting Positive/Negative Sentiment and Specific Emotions",
    description="Enter a piece of text to detect overall sentiment (positive or negative) and specific emotions (anger, disgust, fear, joy, neutral, sadness, surprise).",
)

# Start the web UI. This runs at import time; there is no __main__ guard.
interface.launch()