Update app.py
app.py
CHANGED
@@ -3,69 +3,81 @@ import pandas as pd
 from detoxify import Detoxify
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
-import numpy as np
 import io

-# Load
-tox_model = Detoxify('multilingual')
-
-# Load AI detector model
+# Load models
+tox_model = Detoxify('multilingual')  # 🌍 Multilingual toxicity classifier
 ai_tokenizer = AutoTokenizer.from_pretrained("openai-community/roberta-base-openai-detector")
 ai_model = AutoModelForSequenceClassification.from_pretrained("openai-community/roberta-base-openai-detector")

 # Thresholds
 TOXICITY_THRESHOLD = 0.7
-AI_THRESHOLD = 0.5
+AI_THRESHOLD = 0.5

-def
+def detect_ai(text):
     with torch.no_grad():
         inputs = ai_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
         logits = ai_model(**inputs).logits
-    return round(
+        prob = torch.sigmoid(logits).squeeze().item()
+    return round(prob, 4)

-def
-    tox_df = pd.DataFrame(tox_results, index=comments).round(4)
+def classify_comments(comment_list):
+    results = tox_model.predict(comment_list)
+    df = pd.DataFrame(results, index=comment_list).round(4)
+
+    # Capitalize columns
+    df.columns = [col.replace("_", " ").title().replace(" ", "_") for col in df.columns]
+    df.columns = [col.replace("_", " ") for col in df.columns]

-    #
-    tox_df["⚠️ Warning"] = tox_df.apply(lambda row: "⚠️ High Risk" if any(score > TOXICITY_THRESHOLD for score in row) else "✅ Safe", axis=1)
+    # Add warning & AI detection
+    df["⚠️ Warning"] = df.apply(lambda row: "⚠️ High Risk" if any(score > TOXICITY_THRESHOLD for score in row) else "✅ Safe", axis=1)
+    df["🧪 AI Probability"] = [detect_ai(c) for c in df.index]
+    df["🧪 AI Detection"] = df["🧪 AI Probability"].apply(lambda x: "🤖 Likely AI" if x > AI_THRESHOLD else "🧍 Human")

+    return df

-    fn=
-    inputs=upload,
-    outputs=[output_table, download],
-    title="🌍 Toxic Comment Classifier + AI Text Detector",
-    description="""
-    📥 Upload a .csv file with a 'comment' column.
-    🔍 Each comment will be scored for toxicity (Multilingual model) and AI-generation probability (RoBERTa-based).
-    📤 Download the full report as .csv.
-    """
-)
+def classify_from_textbox(text_input):
+    comment_list = [c.strip() for c in text_input.strip().split('\n') if c.strip()]
+    if not comment_list:
+        return "Please enter at least one comment.", None
+    df = classify_comments(comment_list)
+    csv_data = df.copy()
+    csv_data.insert(0, "Comment", df.index)
+    return df, ("toxicity_predictions.csv", csv_data.to_csv(index=False).encode())
+
+def classify_from_csv(file_obj):
+    df = pd.read_csv(file_obj.name)
+    if 'comment' not in df.columns:
+        return "CSV must contain a 'comment' column.", None
+    comment_list = df['comment'].astype(str).tolist()
+    df = classify_comments(comment_list)
+    csv_data = df.copy()
+    csv_data.insert(0, "Comment", df.index)
+    return df, ("toxicity_predictions.csv", csv_data.to_csv(index=False).encode())
+
+# Gradio Interface
+text_input = gr.Textbox(lines=8, label="💬 Paste Comments (one per line)")
+csv_input = gr.File(label="📥 Or Upload .CSV with 'comment' column")
+output_table = gr.Dataframe(label="📊 Predictions")
+download_button = gr.File(label="📤 Download CSV")
+
+with gr.Blocks(title="Toxicity & AI Comment Detector") as app:
+    gr.Markdown("## 🌍 Toxic Comment & AI Detector\nDetects multilingual toxicity and whether the text is AI-generated.")
+
+    with gr.Tab("📝 Paste Text"):
+        text = text_input
+        btn1 = gr.Button("Analyze Text Comments")
+        output1 = output_table
+        download1 = download_button
+
+    with gr.Tab("📁 Upload CSV"):
+        csv = csv_input
+        btn2 = gr.Button("Analyze CSV File")
+        output2 = output_table
+        download2 = download_button
+
+    btn1.click(fn=classify_from_textbox, inputs=text, outputs=[output1, download1])
+    btn2.click(fn=classify_from_csv, inputs=csv, outputs=[output2, download2])

 if __name__ == "__main__":
     app.launch()
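
A note on the new detect_ai: openai-community/roberta-base-openai-detector is a two-class sequence classifier (real vs. AI-generated text), so a single probability is usually read off with a softmax over the two class logits rather than a sigmoid over the raw logit tensor. The sketch below is a hedged variant under that assumption, not the code in this commit; detect_ai_probability is a hypothetical name, and the fallback label index is an assumption to verify against the printed id2label mapping.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_ID = "openai-community/roberta-base-openai-detector"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

# Confirm which logit index is the AI/"Fake" class before trusting the score.
print(model.config.id2label)
fake_index = model.config.label2id.get("Fake", 0)  # falling back to index 0 is an assumption

def detect_ai_probability(text):  # hypothetical helper, not part of the commit
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        probs = torch.softmax(model(**inputs).logits, dim=-1)[0]
    return round(probs[fake_index].item(), 4)

print(detect_ai_probability("A short hand-written test sentence."))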
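
For the 📁 Upload CSV tab, classify_from_csv only requires a 'comment' column. A minimal sketch of how such a file could be produced (the filename comments.csv and the sample rows are illustrative, not from the commit):

import pandas as pd

# One 'comment' column, one row per comment to score; Detoxify('multilingual')
# also accepts non-English text, so a French sample is included.
pd.DataFrame({"comment": [
    "Merci beaucoup, c'était très utile !",
    "You are an idiot and nobody likes you.",
]}).to_csv("comments.csv", index=False)

Uploading that file in the CSV tab should return one scored row per comment, plus the ⚠️ Warning, 🧪 AI Probability, and 🧪 AI Detection columns added by classify_comments.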