# explorewithai's picture
# Create app.py
# 1b442a6 verified
# raw
# history blame
# 2.25 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch # Import torch
# Hub identifier of the checkpoint used by this app.
model_name = "frameai/PersianSentiment"

# Both objects are created once at import time and reused by every call to
# predict_sentiment, so the checkpoint is not reloaded per request.
loaded_tokenizer = AutoTokenizer.from_pretrained(
    model_name,
)
loaded_model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
)
def predict_sentiment(text):
    """Predict the sentiment of a given text.

    The text is tokenized with the module-level ``loaded_tokenizer`` (with
    padding and truncation enabled), passed through ``loaded_model``, and the
    resulting logits are turned into probabilities with a softmax over the
    last dimension.

    Args:
        text: The text to classify.

    Returns:
        A ``(scores, sentiment)`` tuple. ``scores`` is a dict mapping
        ``"Negative"``, ``"Positive"`` and ``"Neutral"`` to the probabilities
        at class indices 0, 1 and 2 of the first (only) row of the batch.
        ``sentiment`` is the label of the highest-probability class:
        index 0 is ``"Negative"``, index 1 is ``"Positive"``, and any other
        index is ``"Neutral"``.
    """
    inputs = loaded_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )

    # This is inference only: disabling gradient tracking avoids building an
    # autograd graph on every request, which saves memory and time without
    # changing the computed values.
    with torch.no_grad():
        logits = loaded_model(**inputs).logits

    # Softmax gives per-class probabilities; argmax picks the predicted class.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_index = torch.argmax(probabilities, dim=-1).item()

    # NOTE(review): the index-to-label order is hard-coded here; confirm it
    # matches the label order the model was trained with.
    sentiment = {0: "Negative", 1: "Positive"}.get(predicted_index, "Neutral")

    # Return the probabilities as well for a more informative output.
    class_probs = probabilities[0]
    scores = {
        "Negative": float(class_probs[0]),
        "Positive": float(class_probs[1]),
        "Neutral": float(class_probs[2]),
    }
    return scores, sentiment
# Sample inputs offered in the UI. Each entry is wrapped in its own
# single-element list, one value per example row.
examples = [
    [sentence]
    for sentence in (
        "این فیلم عالی بود!",  # positive example
        "من این غذا را دوست نداشتم.",  # negative example
        "هوا خوب است.",  # neutral (could read as slightly positive)
        "کتاب جالبی بود اما کمی خسته کننده هم بود.",  # mixed / neutral
        "اصلا راضی نبودم.",  # negative example
    )
]
# Build the Gradio UI: one text input, and two outputs that receive the
# (probabilities dict, sentiment string) pair returned by predict_sentiment.
iface = gr.Interface(
    title="Persian Sentiment Analysis",
    description="Enter a Persian sentence and get its sentiment (Positive, Negative, or Neutral).",
    fn=predict_sentiment,
    inputs=gr.Textbox(
        label="Enter Persian Text",
        lines=5,
        placeholder="Type your text here...",
    ),
    outputs=[
        gr.Label(label="Sentiment Probabilities"),
        # Second output shows the predicted sentiment string.
        gr.Textbox(label="Predicted Sentiment"),
    ],
    examples=examples,
    # Set to True to update automatically as the user types.
    live=False,
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()