"""Streamlit app that classifies the toxicity of a tweet with a transformer."""

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Display name -> Hugging Face checkpoint offered in the sidebar.
models = {
    'BERT': 'bert-base-uncased',
    'RoBERTa': 'roberta-base',
    'DistilBERT': 'distilbert-base-uncased',
}

# Class index (position in the model's output logits) -> human-readable label.
classes = {
    0: 'Non-Toxic',
    1: 'Toxic',
    2: 'Severely Toxic',
    3: 'Obscene',
    4: 'Threat',
    5: 'Insult',
    6: 'Identity Hate',
}

# Streamlit re-executes this script on every widget interaction, so the
# expensive checkpoint load must be cached.  Prefer the resource cache when
# the installed Streamlit provides it; otherwise fall back to the legacy
# decorator the original script relied on.
if hasattr(st, 'cache_resource'):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model_and_tokenizer(checkpoint, num_labels):
    """Load the tokenizer and classification model for *checkpoint*.

    Args:
        checkpoint: Hugging Face model identifier (a value of ``models``).
        num_labels: Number of output classes for the classification head.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in evaluation mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Size the head to match `classes`; with the library default the model
    # would emit fewer logits than there are labels to report.
    # NOTE(review): the checkpoints in `models` are base language models, so
    # this head is newly initialised -- fine-tuned weights are needed before
    # the scores are meaningful.
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint, num_labels=num_labels
    )
    loaded_model.eval()  # inference only: make sure dropout is disabled
    return loaded_tokenizer, loaded_model


def predict_toxicity(tweet, model, tokenizer):
    """Return the most probable toxicity class for a single tweet.

    Args:
        tweet: Text to classify.
        model: Sequence-classification model producing one logit per class.
        tokenizer: Tokenizer matching *model*.

    Returns:
        A ``(label, probability)`` tuple, where ``label`` comes from
        ``classes`` and ``probability`` is the softmax score as a float.
    """
    inputs = tokenizer(tweet, padding=True, truncation=True, return_tensors='pt')

    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single tweet yields a batch of one, so take row 0 explicitly rather
    # than an argmax over the flattened (batch, classes) array.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
    predicted_class = int(probabilities.argmax())
    predicted_prob = float(probabilities[predicted_class])
    return classes[predicted_class], predicted_prob


def create_table(predictions):
    """Build a results table from a mapping of tweet -> (label, probability).

    Args:
        predictions: Dict mapping each tweet to its ``(label, probability)``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``Tweet``,
        ``Highest Toxicity Class`` and ``Probability``.
    """
    rows = [(tweet, label, prob) for tweet, (label, prob) in predictions.items()]
    return pd.DataFrame(
        rows, columns=['Tweet', 'Highest Toxicity Class', 'Probability']
    )


# ---- User interface --------------------------------------------------------

# Drop-down menu to select the model.
model_name = st.sidebar.selectbox('Select Model', list(models.keys()))
tokenizer, model = load_model_and_tokenizer(models[model_name], len(classes))

st.title('Toxicity Prediction App')
tweet_input = st.text_input('Enter a tweet:')

if st.button('Predict'):
    if not tweet_input.strip():
        # Nothing to classify; avoid reporting a score for empty text.
        st.warning('Please enter a tweet before predicting.')
    else:
        predicted_class_label, predicted_prob = predict_toxicity(
            tweet_input, model, tokenizer
        )
        prediction_text = (
            f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
        )
        st.write(prediction_text)

        # Show the same result as a table.
        predictions = {tweet_input: (predicted_class_label, predicted_prob)}
        table = create_table(predictions)
        st.table(table)