# Install required packages at runtime (alternatively, list them in requirements.txt)
import os
os.system("pip install torch transformers gradio matplotlib")

import torch
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model and tokenizer from Hugging Face Hub
model_name = "HyperX-Sentience/RogueBERT-Toxicity-85K"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Move model to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Toxicity category labels
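# (the order here must match the label order used when the model was trained)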
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Function to predict toxicity
def predict_toxicity(comment):
    inputs = tokenizer([comment], truncation=True, padding="max_length", max_length=128, return_tensors="pt")
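    # Move the input tensors to the same device as the model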
    inputs = {key: val.to(device) for key, val in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
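        # Sigmoid (not softmax): each category gets an independent probability (multi-label scoring)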
        probabilities = torch.sigmoid(logits).cpu().numpy()[0]
    
    toxicity_scores = {label: float(probabilities[i]) for i, label in enumerate(labels)}
    return toxicity_scores

# Function to create a bar chart
def plot_toxicity(comment):
    toxicity_scores = predict_toxicity(comment)
    categories = list(toxicity_scores.keys())
    scores = list(toxicity_scores.values())
    
    plt.figure(figsize=(12, 7), dpi=300, facecolor='black')
    ax = plt.gca()
    ax.set_facecolor('black')
    bars = plt.bar(categories, scores, color='#20B2AA', edgecolor='white', width=0.5)  # Light sea green
    
    plt.xticks(color='white', fontsize=14, rotation=25, ha='right')
    plt.yticks(color='white', fontsize=14)
    plt.title("Toxicity Score Analysis", color='white', fontsize=16)
    plt.ylim(0, 1.1)
    
    for bar in bars:
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.03, f'{yval:.2f}', ha='center', color='white', fontsize=12, fontweight='bold')
    
    plt.tight_layout(pad=2)
    plt.savefig("toxicity_chart.png", facecolor='black', bbox_inches='tight')
    plt.close()
    
    return "toxicity_chart.png"

# Gradio UI
demo = gr.Interface(
    fn=plot_toxicity,
    inputs=gr.Textbox(label="Enter a comment"),
    outputs=gr.Image(type="filepath", label="Toxicity Analysis"),
    title="Toxicity Detector",
    description="Enter a comment to analyze its toxicity scores across different categories.",
)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
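
# A minimal sketch of calling the classifier directly, without the Gradio UI.
# The sample text and printed format below are illustrative only:
#
#     scores = predict_toxicity("have a nice day")
#     for label, score in scores.items():
#         print(f"{label}: {score:.3f}")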