import os

# Install required packages at startup (in a notebook: !pip install torch transformers gradio matplotlib)
os.system("pip install torch transformers gradio matplotlib")
import torch
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load model and tokenizer from Hugging Face Hub
model_name = "HyperX-Sentience/RogueBERT-Toxicity-85K"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Move model to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
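model.eval()  # defensive: from_pretrained already returns the model in eval mode, but being explicit costs nothing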
# Toxicity category labels
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
# Function to predict toxicity
def predict_toxicity(comment):
    inputs = tokenizer([comment], truncation=True, padding="max_length", max_length=128, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Sigmoid (not softmax): this is a multi-label task, so each category gets an independent score
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]
    toxicity_scores = {label: float(probabilities[i]) for i, label in enumerate(labels)}
    return toxicity_scores
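# Quick sanity check (illustrative only; the exact scores depend on the model weights):
#   predict_toxicity("you are a wonderful person")
#   -> {"toxic": 0.02, "severe_toxic": 0.00, "obscene": 0.01, "threat": 0.00, "insult": 0.01, "identity_hate": 0.00}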
# Function to create a bar chart
def plot_toxicity(comment):
    toxicity_scores = predict_toxicity(comment)
    categories = list(toxicity_scores.keys())
    scores = list(toxicity_scores.values())

    # Dark-themed bar chart of the six category scores
    plt.figure(figsize=(12, 7), dpi=300, facecolor='black')
    ax = plt.gca()
    ax.set_facecolor('black')

    bars = plt.bar(categories, scores, color='#20B2AA', edgecolor='white', width=0.5)  # light sea green
    plt.xticks(color='white', fontsize=14, rotation=25, ha='right')
    plt.yticks(color='white', fontsize=14)
    plt.title("Toxicity Score Analysis", color='white', fontsize=16)
    plt.ylim(0, 1.1)

    # Label each bar with its score
    for bar in bars:
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.03, f'{yval:.2f}', ha='center', color='white', fontsize=12, fontweight='bold')

    plt.tight_layout(pad=2)
    plt.savefig("toxicity_chart.png", facecolor='black', bbox_inches='tight')
    plt.close()
    return "toxicity_chart.png"
# Gradio UI
demo = gr.Interface(
    fn=plot_toxicity,
    inputs=gr.Textbox(label="Enter a comment"),
    outputs=gr.Image(type="filepath", label="Toxicity Analysis"),
    title="Toxicity Detector",
    description="Enter a comment to analyze its toxicity scores across different categories.",
)
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
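    # Optional: demo.launch(share=True) creates a temporary public URL when running locally.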