KavinduHansaka commited on
Commit
37d03fb
·
verified ·
1 Parent(s): 8d41ba9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -45
app.py CHANGED
@@ -3,69 +3,81 @@ import pandas as pd
3
  from detoxify import Detoxify
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
6
- import numpy as np
7
  import io
8
 
9
- # Load Detoxify multilingual model
10
- tox_model = Detoxify('multilingual')
11
-
12
- # Load AI detector model
13
  ai_tokenizer = AutoTokenizer.from_pretrained("openai-community/roberta-base-openai-detector")
14
  ai_model = AutoModelForSequenceClassification.from_pretrained("openai-community/roberta-base-openai-detector")
15
 
16
  # Thresholds
17
  TOXICITY_THRESHOLD = 0.7
18
- AI_THRESHOLD = 0.5 # If >0.5, it's likely AI-generated
19
 
20
- def detect_ai_generated(text):
21
  with torch.no_grad():
22
  inputs = ai_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
23
  logits = ai_model(**inputs).logits
24
- probs = torch.sigmoid(logits).squeeze().item()
25
- return round(probs, 4)
26
 
27
- def process_input(file):
28
- df = pd.read_csv(file.name)
29
- if 'comment' not in df.columns:
30
- return "CSV must contain a 'comment' column."
31
-
32
- comments = df['comment'].astype(str).tolist()
33
- tox_results = tox_model.predict(comments)
34
- tox_df = pd.DataFrame(tox_results, index=comments).round(4)
35
 
36
- # Format columns
37
- tox_df.columns = [col.replace("_", " ").title().replace(" ", "_") for col in tox_df.columns]
38
- tox_df.columns = [col.replace("_", " ") for col in tox_df.columns]
 
39
 
40
- # Add warnings
41
- tox_df["⚠️ Warning"] = tox_df.apply(lambda row: "⚠️ High Risk" if any(score > TOXICITY_THRESHOLD for score in row) else "✅ Safe", axis=1)
42
 
43
- # Add AI detection
44
- tox_df["🧪 AI Probability"] = [detect_ai_generated(c) for c in tox_df.index]
45
- tox_df["🧪 AI Detection"] = tox_df["🧪 AI Probability"].apply(lambda x: "🤖 Likely AI" if x > AI_THRESHOLD else "🧍 Human")
 
 
 
 
 
46
 
47
- # Store downloadable CSV
48
- csv_data = tox_df.copy()
49
- csv_data.insert(0, "Comment", tox_df.index)
50
- csv_bytes = csv_data.to_csv(index=False).encode()
51
- return tox_df, ("toxicity_report.csv", csv_bytes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Gradio UI
54
- upload = gr.File(label="📥 Upload .CSV (Must contain 'comment' column)")
55
- output_table = gr.Dataframe(label="📊 Predictions (Multilingual + AI Detection)")
56
- download = gr.File(label="📤 Download Predictions")
 
57
 
58
- app = gr.Interface(
59
- fn=process_input,
60
- inputs=upload,
61
- outputs=[output_table, download],
62
- title="🌍 Toxic Comment Classifier + AI Text Detector",
63
- description="""
64
- 📥 Upload a .csv file with a 'comment' column.
65
- 🔍 Each comment will be scored for toxicity (Multilingual model) and AI-generation probability (RoBERTa-based).
66
- 📤 Download the full report as .csv.
67
- """
68
- )
69
 
70
  if __name__ == "__main__":
71
  app.launch()
 
3
  from detoxify import Detoxify
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
 
6
  import io
7
 
8
+ # Load models
9
+ tox_model = Detoxify('multilingual') # 🌍 Multilingual toxicity classifier
 
 
10
  ai_tokenizer = AutoTokenizer.from_pretrained("openai-community/roberta-base-openai-detector")
11
  ai_model = AutoModelForSequenceClassification.from_pretrained("openai-community/roberta-base-openai-detector")
12
 
13
  # Thresholds
14
  TOXICITY_THRESHOLD = 0.7
15
+ AI_THRESHOLD = 0.5
16
 
17
+ def detect_ai(text):
18
  with torch.no_grad():
19
  inputs = ai_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
20
  logits = ai_model(**inputs).logits
21
+ prob = torch.sigmoid(logits).squeeze().item()
22
+ return round(prob, 4)
23
 
24
+ def classify_comments(comment_list):
25
+ results = tox_model.predict(comment_list)
26
+ df = pd.DataFrame(results, index=comment_list).round(4)
27
+
28
+ # Capitalize columns
29
+ df.columns = [col.replace("_", " ").title().replace(" ", "_") for col in df.columns]
30
+ df.columns = [col.replace("_", " ") for col in df.columns]
 
31
 
32
+ # Add warning & AI detection
33
+ df["⚠️ Warning"] = df.apply(lambda row: "⚠️ High Risk" if any(score > TOXICITY_THRESHOLD for score in row) else "✅ Safe", axis=1)
34
+ df["🧪 AI Probability"] = [detect_ai(c) for c in df.index]
35
+ df["🧪 AI Detection"] = df["🧪 AI Probability"].apply(lambda x: "🤖 Likely AI" if x > AI_THRESHOLD else "🧍 Human")
36
 
37
+ return df
 
38
 
39
+ def classify_from_textbox(text_input):
40
+ comment_list = [c.strip() for c in text_input.strip().split('\n') if c.strip()]
41
+ if not comment_list:
42
+ return "Please enter at least one comment.", None
43
+ df = classify_comments(comment_list)
44
+ csv_data = df.copy()
45
+ csv_data.insert(0, "Comment", df.index)
46
+ return df, ("toxicity_predictions.csv", csv_data.to_csv(index=False).encode())
47
 
48
+ def classify_from_csv(file_obj):
49
+ df = pd.read_csv(file_obj.name)
50
+ if 'comment' not in df.columns:
51
+ return "CSV must contain a 'comment' column.", None
52
+ comment_list = df['comment'].astype(str).tolist()
53
+ df = classify_comments(comment_list)
54
+ csv_data = df.copy()
55
+ csv_data.insert(0, "Comment", df.index)
56
+ return df, ("toxicity_predictions.csv", csv_data.to_csv(index=False).encode())
57
+
58
+ # Gradio Interface
59
+ text_input = gr.Textbox(lines=8, label="💬 Paste Comments (one per line)")
60
+ csv_input = gr.File(label="📥 Or Upload .CSV with 'comment' column")
61
+ output_table = gr.Dataframe(label="📊 Predictions")
62
+ download_button = gr.File(label="📤 Download CSV")
63
+
64
+ with gr.Blocks(title="Toxicity & AI Comment Detector") as app:
65
+ gr.Markdown("## 🌍 Toxic Comment & AI Detector\nDetects multilingual toxicity and whether the text is AI-generated.")
66
+
67
+ with gr.Tab("📝 Paste Text"):
68
+ text = text_input
69
+ btn1 = gr.Button("Analyze Text Comments")
70
+ output1 = output_table
71
+ download1 = download_button
72
 
73
+ with gr.Tab("📁 Upload CSV"):
74
+ csv = csv_input
75
+ btn2 = gr.Button("Analyze CSV File")
76
+ output2 = output_table
77
+ download2 = download_button
78
 
79
+ btn1.click(fn=classify_from_textbox, inputs=text, outputs=[output1, download1])
80
+ btn2.click(fn=classify_from_csv, inputs=csv, outputs=[output2, download2])
 
 
 
 
 
 
 
 
 
81
 
82
  if __name__ == "__main__":
83
  app.launch()