HyperX-Sentience committed
Commit 4acd2a4 · verified · 1 parent: 236e362

Update app.py

Files changed (1):
  app.py (+49 -35)
app.py CHANGED
@@ -1,61 +1,75 @@
 import os
 os.system("pip install torch transformers gradio matplotlib")
 
+# Install required packages
+# !pip install torch transformers gradio matplotlib
 
 import torch
 import gradio as gr
-import numpy as np
 import matplotlib.pyplot as plt
-import pandas as pd
+import numpy as np
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
-torch.set_num_threads(torch.get_num_threads())
-
-# Load the trained model and tokenizer from Hugging Face Hub
-model_path = "HyperX-Sentience/RogueBERT-Toxicity-85K"
-model = AutoModelForSequenceClassification.from_pretrained(model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-# Move the model to CUDA if available
+# Load model and tokenizer from Hugging Face Hub
+model_name = "HyperX-Sentience/RogueBERT-Toxicity-85K"
+model = AutoModelForSequenceClassification.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# Move model to CUDA if available
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-# Define toxicity labels
+# Toxicity category labels
 labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
 
+# Function to predict toxicity
 def predict_toxicity(comment):
-    """Predicts the toxicity levels of a given comment."""
-    inputs = tokenizer(comment, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
+    inputs = tokenizer([comment], truncation=True, padding="max_length", max_length=128, return_tensors="pt")
     inputs = {key: val.to(device) for key, val in inputs.items()}
 
     with torch.no_grad():
         outputs = model(**inputs)
-    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]
+    logits = outputs.logits
+    probabilities = torch.sigmoid(logits).cpu().numpy()[0]
 
-    return {labels[i]: float(probabilities[i]) for i in range(len(labels))}
+    toxicity_scores = {label: float(probabilities[i]) for i, label in enumerate(labels)}
+    return toxicity_scores
 
-def format_toxicity_data(comment):
-    """Formats the toxicity scores for a modern bar graph."""
-    scores = predict_toxicity(comment)
-    df = pd.DataFrame({"Category": list(scores.keys()), "Score": list(scores.values())})
-    return df
+# Function to create a bar chart
+def plot_toxicity(comment):
+    toxicity_scores = predict_toxicity(comment)
+    categories = list(toxicity_scores.keys())
+    scores = list(toxicity_scores.values())
+
+    plt.figure(figsize=(8, 5), facecolor='black')
+    ax = plt.gca()
+    ax.set_facecolor('black')
+    bars = plt.bar(categories, scores, color='#20B2AA', edgecolor='white')  # Sea green
+
+    plt.xticks(color='white', fontsize=12)
+    plt.yticks(color='white', fontsize=12)
+    plt.title("Toxicity Score Analysis", color='white', fontsize=14)
+    plt.ylim(0, 1)
+
+    for bar in bars:
+        yval = bar.get_height()
+        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.02, f'{yval:.2f}', ha='center', color='white', fontsize=10)
+
+    plt.tight_layout()
+    plt.savefig("toxicity_chart.png", facecolor='black')
+    plt.close()
+
+    return "toxicity_chart.png"
 
-# Gradio interface
+# Gradio UI
 demo = gr.Interface(
-    fn=format_toxicity_data,
-    inputs=gr.Textbox(label="Enter a comment:"),
-    outputs=gr.BarPlot(
-        value=None,
-        x="Category",
-        y="Score",
-        title="Toxicity Analysis",
-        y_lim=[0, 1],
-        color="blue",
-        label="Toxicity Scores",
-        interactive=False
-    ),
-    title="Toxicity Detection with RogueBERT",
-    description="Enter a comment to analyze its toxicity levels. The results will be displayed as a modern bar chart."
+    fn=plot_toxicity,
+    inputs=gr.Textbox(label="Enter a comment"),
+    outputs=gr.Image(type="file", label="Toxicity Analysis"),
+    title="Toxicity Detector",
+    description="Enter a comment to analyze its toxicity scores across different categories.",
 )
 
-demo.launch()
+# Launch the Gradio app
+if __name__ == "__main__":
+    demo.launch()
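
Note on the new output component (an observation, not part of this commit): app.py installs gradio unpinned, and recent Gradio releases restrict gr.Image's type argument to "numpy", "pil", and "filepath", so type="file" may be rejected at startup. Below is a minimal sketch under that assumption, reusing the committed plot_toxicity function, which already returns the path of the saved PNG:

import gradio as gr

# Sketch only: assumes a recent Gradio where gr.Image expects type="filepath"
# when the wrapped function returns a path to an image file.
demo = gr.Interface(
    fn=plot_toxicity,  # defined in the committed app.py
    inputs=gr.Textbox(label="Enter a comment"),
    outputs=gr.Image(type="filepath", label="Toxicity Analysis"),
    title="Toxicity Detector",
    description="Enter a comment to analyze its toxicity scores across different categories.",
)

if __name__ == "__main__":
    demo.launch()

Alternatively, returning the Matplotlib figure to a gr.Plot output would avoid writing toxicity_chart.png to disk, which matters if several requests render the chart concurrently; either variant keeps the structure introduced by this commit.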