HyperX-Sentience's picture
Update app.py
696726b verified
raw
history blame
2.19 kB
import os
import subprocess
import sys

# Best-effort runtime install of the packages imported below.
# `sys.executable -m pip` guarantees the packages land in the interpreter that
# is running this script (a bare `pip` on PATH may belong to another Python),
# and the argument list avoids going through a shell.
# NOTE(review): consider declaring these in a requirements file instead of
# installing on every start-up.
_pip_result = subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "torch",
        "transformers",
        "gradio",
        "matplotlib",
    ],
    check=False,  # keep start-up going if the packages are already present
)
if _pip_result.returncode != 0:
    print(
        f"warning: pip install exited with status {_pip_result.returncode}; "
        "continuing with the packages that are already installed",
        file=sys.stderr,
    )
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Re-apply the thread count PyTorch currently reports for intra-op work.
torch.set_num_threads(torch.get_num_threads())

# Hub repository holding both the fine-tuned classifier and its tokenizer.
model_path = "HyperX-Sentience/RogueBERT-Toxicity-85K"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the classifier and place it on the chosen device in one step
# (`.to()` returns the module itself), then load the matching tokenizer.
model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Names for the classifier outputs, in output-index order.
labels = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]
def predict_toxicity(comment):
    """Score a comment against every toxicity label.

    The text is tokenized (truncated and padded to 128 tokens), run through
    the module-level ``model`` on ``device`` with gradient tracking disabled,
    and each logit is squashed with its own sigmoid, so the scores are
    independent and need not sum to 1.

    Args:
        comment: Text to analyse. ``None`` is treated as an empty string
            instead of being handed to the tokenizer.

    Returns:
        A dict mapping each entry of ``labels`` to its score as a Python
        float.
    """
    if comment is None:
        comment = ""

    encoded = tokenizer(
        comment,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    )
    # The input tensors must live on the same device as the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits

    # The batch holds exactly one comment, hence row 0.
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]

    # Index by position (rather than zip) so a model that emits fewer outputs
    # than there are labels still fails loudly instead of dropping labels.
    return {label: float(probabilities[i]) for i, label in enumerate(labels)}
def visualize_toxicity(comment):
    """Render the toxicity scores of ``comment`` as a bar-chart PNG.

    Args:
        comment: Text to analyse; forwarded unchanged to ``predict_toxicity``.

    Returns:
        Filesystem path (str) of a freshly written PNG file. Each call gets
        its own file, so overlapping requests cannot overwrite one another's
        chart. The file is not deleted by this function.
    """
    # Local imports keep this block self-contained; matplotlib is already a
    # dependency of this file and tempfile is standard library.
    import tempfile

    from matplotlib.figure import Figure

    scores = predict_toxicity(comment)

    # A standalone Figure avoids pyplot's process-global "current figure",
    # which is shared by every request handled in this process, and needs no
    # explicit close() to be released.
    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.bar(
        list(scores),
        list(scores.values()),
        color=['blue', 'red', 'green', 'purple', 'orange', 'brown'],
    )
    ax.set_ylim(0, 1)
    ax.set_ylabel("Toxicity Score")
    ax.set_title("Toxicity Analysis")
    ax.tick_params(axis="x", labelrotation=45)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # delete=False: the caller reads the file after this function returns.
    with tempfile.NamedTemporaryFile(
        prefix="toxicity_plot_", suffix=".png", delete=False
    ) as image_file:
        fig.savefig(image_file, format="png")
    return image_file.name
# Gradio interface: one text input, one image output showing the score chart.
comment_box = gr.Textbox(label="Enter a comment:")
score_image = gr.Image(type="filepath", label="Toxicity Scores")

demo = gr.Interface(
    fn=visualize_toxicity,
    inputs=comment_box,
    outputs=score_image,
    title="Toxicity Detection with RogueBERT",
    description="Enter a comment to analyze its toxicity levels. The results will be displayed as a bar chart.",
)

# Start the web server for the interface.
demo.launch()