APJ23's picture
Update app.py
3c380c2
raw
history blame
2.2 kB
import streamlit as st
import pandas as pd
import torch
import asyncio
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
# Do not call gr.Interface.load(...).launch() here: Gradio's launch() blocks
# the calling thread by default (unless prevent_thread_lock=True), so the
# Streamlit UI defined further down would never be rendered.
MODEL_NAME = "APJ23/MultiHeaded_Sentiment_Analysis_Model"

# Streamlit re-executes this whole script on every widget interaction, so the
# loaded tokenizer/model are cached for the lifetime of the server process.
# Newer Streamlit releases provide st.cache_resource; older ones only st.cache.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_resource
def _load_model_and_tokenizer(model_name):
    """Return ``(tokenizer, model)`` for the Hugging Face checkpoint *model_name*.

    Both objects are loaded with the same (default) download settings; loading
    only the tokenizer with ``local_files_only=True`` made it fail whenever its
    files were not already in the local cache.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSequenceClassification.from_pretrained(model_name),
    )


tokenizer, model = _load_model_and_tokenizer(MODEL_NAME)
# Human-readable label for each output index of the classifier, keyed 0..6.
# NOTE(review): assumes index i of the model's logits corresponds to the i-th
# label below — confirm against the model's config (id2label).
classes = dict(enumerate([
    'Non-Toxic',
    'Toxic',
    'Severely Toxic',
    'Obscene',
    'Threat',
    'Insult',
    'Identity Hate',
]))
def predict_toxicity(tweet, model, tokenizer):
    """Classify one tweet and return its highest-scoring toxicity class.

    Args:
        tweet: The text to classify.
        model: A Hugging Face sequence-classification model whose output
            exposes ``.logits``.
        tokenizer: The tokenizer matching *model*.

    Returns:
        A ``(label, probability)`` tuple: the ``classes`` label of the
        highest-scoring output and its softmax probability as a Python float.

    Raises:
        KeyError: if the winning output index has no entry in ``classes``.
    """
    # Not cached with st.cache: Streamlit's legacy hasher cannot hash the
    # tokenizer/model arguments. A single-tweet forward pass is cheap anyway.
    inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # NOTE(review): the model name suggests a multi-head / multi-label model,
    # for which per-class sigmoid scores may fit better than softmax — confirm.
    probs = torch.softmax(logits, dim=1)[0]
    # Convert to a plain int: tensors hash by identity, so using the tensor
    # itself as a dict key would never match the integer keys of `classes`.
    predicted_class = int(torch.argmax(probs).item())
    return classes[predicted_class], probs[predicted_class].item()
def create_table(predictions):
    """Turn a ``{tweet: (label, probability)}`` mapping into a DataFrame.

    The result has one row per tweet, in the mapping's iteration order, with
    the columns 'Tweet', 'Highest Toxicity Class' and 'Probability'.
    """
    entries = list(predictions.items())
    columns = {
        'Tweet': [text for text, _ in entries],
        'Highest Toxicity Class': [result[0] for _, result in entries],
        'Probability': [result[1] for _, result in entries],
    }
    return pd.DataFrame(columns)
async def run_async_prediction(tweet, model, tokenizer):
    """Run ``predict_toxicity`` in the event loop's default executor.

    Args:
        tweet: The text to classify.
        model: Passed through to ``predict_toxicity``.
        tokenizer: Passed through to ``predict_toxicity``.

    Returns:
        Whatever ``predict_toxicity`` returns (a ``(label, probability)`` tuple).
    """
    # Inside a coroutine, get_running_loop() is the documented way to reach
    # the current loop; get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, predict_toxicity, tweet, model, tokenizer)
# Streamlit page: read a tweet, classify it on "Predict", and show the result.
st.title('Toxicity Prediction App')
tweet_input = st.text_input('Enter a tweet to check for toxicity')
if st.button('Predict'):
    if not tweet_input.strip():
        # Nothing meaningful to classify; avoid running the model on "".
        st.warning('Please enter a tweet first.')
    else:
        # asyncio.run creates a fresh loop, runs the coroutine and always
        # closes the loop, even if the prediction raises.
        prediction = asyncio.run(run_async_prediction(tweet_input, model, tokenizer))
        predictions = {tweet_input: prediction}
        predicted_class_label, predicted_prob = prediction
        st.write(f'Prediction: {predicted_class_label} ({predicted_prob:.2f})')
        st.table(create_table(predictions))