HyperX-Sentience committed
Commit 0475377 · verified · 1 Parent(s): 3e0d2a4

Create app.py

Files changed (1): app.py (+60, -0)
app.py ADDED
import torch
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the trained model and tokenizer from the Hugging Face Hub
model_path = "HyperX-Sentience/RogueBERT-Toxicity-85K"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Move the model to CUDA if available and switch to inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Toxicity labels, in the order of the model's output logits
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def predict_toxicity(comment):
    """Predict the toxicity levels of a given comment as a label -> probability dict."""
    inputs = tokenizer(comment, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        # Multi-label classification: apply a sigmoid to each logit independently
        probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    return {labels[i]: float(probabilities[i]) for i in range(len(labels))}


def visualize_toxicity(comment):
    """Generate a bar chart showing toxicity levels."""
    scores = predict_toxicity(comment)

    # Create the bar chart
    plt.figure(figsize=(6, 4))
    plt.bar(scores.keys(), scores.values(), color=["blue", "red", "green", "purple", "orange", "brown"])
    plt.ylim(0, 1)
    plt.ylabel("Toxicity Score")
    plt.title("Toxicity Analysis")
    plt.xticks(rotation=45)
    plt.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()  # keep the rotated x labels from being clipped

    # Save the plot to a file so Gradio can display it
    plt.savefig("toxicity_plot.png")
    plt.close()

    return "toxicity_plot.png"


# Gradio interface
demo = gr.Interface(
    fn=visualize_toxicity,
    inputs=gr.Textbox(label="Enter a comment:"),
    outputs=gr.Image(type="filepath", label="Toxicity Scores"),  # "filepath" (not "file") is the value current Gradio accepts
    title="Toxicity Detection with RogueBERT",
    description="Enter a comment to analyze its toxicity levels. The results will be displayed as a bar chart.",
)

demo.launch()
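
Once the Space is running, the same endpoint can be exercised programmatically. Below is a minimal sketch using gradio_client; the local serving URL and the auto-generated "/predict" endpoint name are assumptions, not part of this commit.

# Minimal sketch: query the running app with gradio_client.
# Assumes the default local URL and the auto-generated "/predict"
# endpoint name (both assumptions, not from the commit).
from gradio_client import Client

client = Client("http://127.0.0.1:7860")
result = client.predict("you are a wonderful person", api_name="/predict")
print(result)  # local path to the rendered bar-chart image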
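
One design caveat: visualize_toxicity writes every chart to the same toxicity_plot.png, so concurrent requests can overwrite each other's output before it is served. A sketch of an alternative that sidesteps the shared file by returning the Matplotlib figure object through gr.Plot (the helper name visualize_toxicity_fig is hypothetical):

# Alternative sketch: return the figure object directly so no shared
# filename is involved; gr.Plot renders Matplotlib figures as-is.
import gradio as gr
import matplotlib.pyplot as plt

def visualize_toxicity_fig(comment):  # hypothetical helper; reuses predict_toxicity above
    scores = predict_toxicity(comment)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(list(scores.keys()), list(scores.values()))
    ax.set_ylim(0, 1)
    ax.set_ylabel("Toxicity Score")
    ax.tick_params(axis="x", rotation=45)
    fig.tight_layout()
    return fig

demo_fig = gr.Interface(fn=visualize_toxicity_fig, inputs=gr.Textbox(), outputs=gr.Plot())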