|
import gradio as gr |
|
import torch |
|
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
|
|
|
|
|
# Hugging Face Hub checkpoint used for both the tokenizer and the classifier.
model_name = "bhadresh-savani/distilbert-base-uncased-finetuned-sentiment"

# Both are fetched (or read from the local HF cache) once, at import time.
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name)

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Move the model weights to the chosen device; inputs are moved to match
# before each forward pass.
model.to(device)
|
|
|
|
|
def predict_sentiment(text):
    """Classify ``text`` with the module-level DistilBERT model.

    Args:
        text: The string to classify. It is tokenized with ``truncation=True``,
            so over-long input is cut to the tokenizer's maximum length rather
            than raising.

    Returns:
        int: Index of the highest-scoring class logit. Use
        ``model.config.id2label`` to turn it into a human-readable label.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

    # Pure inference: disable autograd so no computation graph is built,
    # saving memory and time on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = torch.argmax(outputs.logits, dim=-1)
    return predictions.item()
|
|
|
|
|
def _predict_label(text):
    """Return a display-friendly prediction for ``text``.

    ``predict_sentiment`` yields a raw class index. Translate it through the
    model config's ``id2label`` mapping so the UI shows the label name, and
    fall back to the index as a string if the mapping has no entry for it.
    """
    label_id = predict_sentiment(text)
    return model.config.id2label.get(label_id, str(label_id))


with gr.Blocks() as sentiment_app:
    gr.Markdown("<h1>Sentiment Analysis with DistilBERT</h1>")

    input_box = gr.Textbox(label="Input Text", placeholder="Enter text to analyze sentiment")
    output_box = gr.Textbox(label="Sentiment Result", placeholder="Sentiment result will appear here")

    submit_button = gr.Button("Analyze Sentiment")

    # Route through _predict_label so the result box shows the label name
    # instead of a bare class index.
    submit_button.click(fn=_predict_label, inputs=input_box, outputs=output_box)
|
|
|
|
|
def main():
    """Start the Gradio server for the sentiment app (blocks until stopped)."""
    sentiment_app.launch()


# Only launch when run as a script; importing this module just builds the UI.
if __name__ == "__main__":
    main()
|
|