LTP / app.py
sashdev's picture
Update app.py
5d50a44 verified
raw
history blame
1.29 kB
import gradio as gr
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Checkpoint identifier handed to both from_pretrained() calls below, so the
# tokenizer and the classification weights come from the same source.
model_name = "bhadresh-savani/distilbert-base-uncased-finetuned-sentiment"

# Choose the inference device: CUDA when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# Load the classifier and place it on the chosen device in one expression;
# Module.to() returns the module itself, so `model` is the loaded model.
model = DistilBertForSequenceClassification.from_pretrained(model_name).to(device)
# Define the prediction function
def predict_sentiment(text):
    """Classify *text* and return the index of the highest-scoring class.

    The text is tokenized with truncation enabled, moved to the module-level
    ``device`` and run through the module-level ``model``.

    Args:
        text: The input to classify. It must produce exactly one prediction
            (i.e. a single string), because the result is extracted with
            ``Tensor.item()``.

    Returns:
        The argmax over the model's output logits as a plain Python int.
    """
    # NOTE(review): callers receive the raw class index, not a label name;
    # presumably the names live in model.config.id2label -- confirm before
    # changing what is returned.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

    # Inference only: turn autograd off so no gradient graph is built and
    # kept alive for every request.
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = torch.argmax(outputs.logits, dim=-1)
    return predictions.item()
# Build the UI: a heading, an input textbox, an output textbox and a button.
# Components are created in display order inside the Blocks context.
with gr.Blocks() as sentiment_app:
    gr.Markdown("<h1>Sentiment Analysis with DistilBERT</h1>")

    input_box = gr.Textbox(
        label="Input Text",
        placeholder="Enter text to analyze sentiment",
    )
    output_box = gr.Textbox(
        label="Sentiment Result",
        placeholder="Sentiment result will appear here",
    )
    submit_button = gr.Button("Analyze Sentiment")

    # Clicking the button sends the input text through predict_sentiment and
    # writes whatever it returns into the result box.
    submit_button.click(
        fn=predict_sentiment,
        inputs=input_box,
        outputs=output_box,
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    sentiment_app.launch()