aeresd committed
Commit e21b576 · verified · 1 Parent(s): 40382a6

Update app.py

Files changed (1)
  1. app.py +88 -110
app.py CHANGED
@@ -5,8 +5,10 @@ from PIL import Image
  import pytesseract
  import pandas as pd
  import plotly.express as px

- # ✅ Step 1: Emoji translation model
  emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
  emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
  emoji_model = AutoModelForCausalLM.from_pretrained(
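
The block above loads the fine-tuned Qwen1.5 checkpoint that rewrites emoji-laden text into plain text before classification. A minimal standalone sketch of that translation step, reusing the `输入:`/`输出:` prompt convention that `classify_text_elements()` applies later in this diff (the `translate_emoji` helper is illustrative, not part of app.py):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

def translate_emoji(text: str) -> str:
    """Paraphrase an emoji-laden sentence into plain text (illustrative helper)."""
    prompt = f"输入:{text}\n输出:"  # same prompt format as in app.py
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    decoded = tokenizer.decode(out[0], skip_special_tokens=True)
    # Keep only the generated continuation after "输出:"
    return decoded.split("输出:")[-1].strip()

print(translate_emoji("你是🐷"))
```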
@@ -16,7 +18,7 @@ emoji_model = AutoModelForCausalLM.from_pretrained(
  ).to("cuda" if torch.cuda.is_available() else "cpu")
  emoji_model.eval()

- # ✅ Step 2: Offensive-text classification models
  model_options = {
      "Toxic-BERT": "unitary/toxic-bert",
      "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
@@ -26,106 +28,89 @@ model_options = {
  # ✅ Page configuration
  st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")

- # ✅ Sidebar configuration
  with st.sidebar:
      st.header("🧠 Configuration")
      selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
      selected_model_id = model_options[selected_model]
-     classifier = pipeline(
-         "text-classification",
-         model=selected_model_id,
-         device=0 if torch.cuda.is_available() else -1,
-         return_all_scores=True
-     )

  # Initialize the history
  if "history" not in st.session_state:
      st.session_state.history = []

- # Classification function (optimized version)
- def classify_emoji_text(text: str):
-     prompt = f"输入:{text}\n输出:"
-     input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
-     with torch.no_grad():
-         output_ids = emoji_model.generate(**input_ids, max_new_tokens=64, do_sample=False)
-     decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
-
-     # Get all classification results
-     all_results = classifier(translated_text)
-
-     # Radar-chart category mapping rules
-     radar_categories = ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"]
-     radar_scores = {category: 0.0 for category in radar_categories}
-
-     # Model-specific mapping rules
-     model_mappings = {
-         "Toxic-BERT": {
-             "toxic": "Vulgarity",
-             "severe_toxic": "Abuse",
-             "obscene": "Vulgarity",
-             "threat": "Hate Speech",
-             "insult": "Insult",
-             "identity_hate": "Discrimination"
-         },
-         "Roberta Offensive": {
-             "offensive": ["Insult", "Abuse", "Vulgarity"]
-         }
-     }

-     # Dynamically build the radar scores
-     for result in all_results:
-         label = result['label']
-         score = result['score']
-
-         if selected_model == "Toxic-BERT":
-             mapped_category = model_mappings["Toxic-BERT"].get(label)
-             if mapped_category and score > radar_scores[mapped_category]:
-                 radar_scores[mapped_category] = score
-         elif selected_model == "Roberta Offensive" and label == "offensive":
-             for category in model_mappings["Roberta Offensive"]["offensive"]:
-                 if score > radar_scores[category]:
-                     radar_scores[category] = score
-
-     # Get the primary classification result
-     primary_result = max(all_results, key=lambda x: x['score'])
-     label = primary_result["label"]
-     score = primary_result["score"]
-     reasoning = f"The sentence was flagged as '{label}' due to potentially offensive phrases. Consider replacing emotionally charged, ambiguous, or abusive terms."
-
-     # Store in the history
-     st.session_state.history.append({
-         "text": text,
-         "translated": translated_text,
-         "label": label,
-         "score": score,
-         "reason": reasoning,
-         "radar_scores": radar_scores
-     })
-     return translated_text, label, score, reasoning, radar_scores
-
- # Main UI
  st.title("🚨 Emoji Offensive Text Detector & Analysis Dashboard")

- # Text input & analysis
  st.subheader("1. 输入与分类")
- default_text = "你是🐷"
  text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)

  if st.button("🚦 Analyze Text"):
      with st.spinner("🔍 Processing..."):
          try:
-             translated, label, score, reason, radar = classify_emoji_text(text)
-             st.markdown("**Translated sentence:**")
-             st.code(translated, language="text")
-             st.markdown(f"**Prediction:** {label}")
-             st.markdown(f"**Confidence Score:** {score:.2%}")
-             st.markdown("**Model Explanation:**")
-             st.info(reason)
          except Exception as e:
              st.error(f"❌ An error occurred:\n{e}")

- # Image analysis
  st.markdown("---")
  st.subheader("2. 图片 OCR & 分类")
  uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg","jpeg","png"])
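
The removed `pipeline(...)` call above requested `return_all_scores=True`, and the old `classify_emoji_text()` aggregated every label's score into the radar categories. The replacement pipeline later in this diff drops that flag, so `classifier(text)[0]` yields only the single top label. If all per-label scores are still wanted for the radar aggregation, a hedged sketch using `top_k=None` (the non-deprecated spelling of `return_all_scores=True`) could look like this; the exact output nesting can vary by transformers version:

```python
import torch
from transformers import pipeline

# Hypothetical variant of the sidebar classifier: top_k=None returns every
# label with its score instead of only the top prediction.
classifier = pipeline(
    "text-classification",
    model="unitary/toxic-bert",
    device=0 if torch.cuda.is_available() else -1,
    top_k=None,
)

scores = classifier("you are an idiot")
# Depending on the transformers version the result may be nested one level deep.
if scores and isinstance(scores[0], list):
    scores = scores[0]
for item in scores:
    print(item["label"], round(item["score"], 3))
```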
@@ -137,13 +122,12 @@ if uploaded_file:
      if ocr_text:
          st.markdown("**Extracted Text:**")
          st.code(ocr_text)
-         translated, label, score, reason, radar = classify_emoji_text(ocr_text)
-         st.markdown("**Translated sentence:**")
-         st.code(translated, language="text")
-         st.markdown(f"**Prediction:** {label}")
-         st.markdown(f"**Confidence Score:** {score:.2%}")
-         st.markdown("**Model Explanation:**")
-         st.info(reason)
      else:
          st.info("⚠️ No text detected in the image.")

@@ -151,30 +135,24 @@ if uploaded_file:
  st.markdown("---")
  st.subheader("3. Violation Analysis Dashboard")
  if st.session_state.history:
-     # Show the history
      df = pd.DataFrame(st.session_state.history)
-     st.markdown("### 🧾 Offensive Terms & Suggestions")
      for item in st.session_state.history:
-         st.markdown(f"- 🔹 **Input:** {item['text']}")
-         st.markdown(f"  - **Translated:** {item['translated']}")
-         st.markdown(f"  - **Label:** {item['label']} with **{item['score']:.2%}** confidence")
-         st.markdown(f"  - 🔧 **Suggestion:** {item['reason']}")
-
-     # Dynamically build the radar chart
-     latest_radar = st.session_state.history[-1]["radar_scores"]
-     radar_df = pd.DataFrame({
-         "Category": latest_radar.keys(),
-         "Score": latest_radar.values()
-     })
-     radar_fig = px.line_polar(
-         radar_df,
-         r='Score',
-         theta='Category',
-         line_close=True,
-         title="⚠️ Risk Radar by Category",
-         range_r=[0,1]
-     )
-     radar_fig.update_traces(fill='toself', line_color='red')
      st.plotly_chart(radar_fig)
  else:
      st.info("⚠️ No classification data available yet.")
 
  import pytesseract
  import pandas as pd
  import plotly.express as px
+ import re
+ from collections import defaultdict

+ # ✅ Step 1: Emoji translation model (the model you fine-tuned yourself)
  emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
  emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
  emoji_model = AutoModelForCausalLM.from_pretrained(
 
  ).to("cuda" if torch.cuda.is_available() else "cpu")
  emoji_model.eval()

+ # ✅ Step 2: Selectable offensive-text classification models
  model_options = {
      "Toxic-BERT": "unitary/toxic-bert",
      "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
 
  # ✅ Page configuration
  st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")

+ # ✅ Sidebar: model selection
  with st.sidebar:
      st.header("🧠 Configuration")
      selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
      selected_model_id = model_options[selected_model]
+     classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)

  # Initialize the history
  if "history" not in st.session_state:
      st.session_state.history = []

+ # Map labels to radar-chart categories (can be tuned per model)
+ label_to_category = {
+     "toxic": "Abuse",
+     "offensive": "Insult",
+     "insult": "Insult",
+     "threat": "Hate Speech",
+     "obscene": "Vulgarity",
+     "hate": "Hate Speech",
+     "discrimination": "Discrimination"
+ }
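
The lookup above keys on lowercased labels such as `toxic` and `offensive`. Toxic-BERT does emit those names, but depending on its config the Cardiff offensive model may report generic ids like `LABEL_0`/`LABEL_1`, in which case every element falls into the `Others` bucket. A hedged normalization sketch, with the alias mapping explicitly an assumption to verify against the model's `id2label`:

```python
# Hedged sketch: normalize raw pipeline labels before the category lookup.
# Whether the offensive model really emits "LABEL_0"/"LABEL_1" depends on its
# config.id2label and should be verified; the alias mapping is an assumption.
RAW_LABEL_ALIASES = {
    "label_0": "non-offensive",  # assumed id2label for the Cardiff model
    "label_1": "offensive",
}

def normalize_label(raw_label: str) -> str:
    label = raw_label.lower()
    return RAW_LABEL_ALIASES.get(label, label)

# Usage inside classify_text_elements (illustrative):
# label = normalize_label(classification["label"])
# category = label_to_category.get(label, "Others")
```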

+ # ✅ Offensiveness analysis function (element by element)
+ def classify_text_elements(text: str):
+     elements = re.split(r"[,。,、!!??\s\n]", text)
+     elements = [e for e in elements if e.strip()]
+
+     results = []
+     radar_scores = defaultdict(float)
+
+     for element in elements:
+         prompt = f"输入:{element}\n输出:"
+         input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
+         with torch.no_grad():
+             output_ids = emoji_model.generate(**input_ids, max_new_tokens=64, do_sample=False)
+         decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+         translated = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
+
+         classification = classifier(translated)[0]
+         label = classification["label"].lower()
+         score = classification["score"]
+
+         category = label_to_category.get(label, "Others")
+         radar_scores[category] += score
+
+         reasoning = f"'{element}' was flagged as '{label}' → '{category}' due to potential offensiveness."
+         results.append({
+             "text": element,
+             "translated": translated,
+             "label": label,
+             "category": category,
+             "score": score,
+             "reason": reasoning
+         })
+
+     st.session_state.history.extend(results)
+     return results, radar_scores
+
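
One detail of the splitting rule above: the character class includes `\s`, so English input is split on every space and classified word by word, while Chinese clauses stay intact. A hedged alternative that splits only on clause punctuation (the `split_into_elements` helper is illustrative):

```python
import re

# Hedged alternative to the splitting rule in classify_text_elements():
# dropping \s keeps whole clauses together instead of single English words.
CLAUSE_DELIMITERS = r"[,。,、!!??\n]"

def split_into_elements(text: str) -> list[str]:
    return [e.strip() for e in re.split(CLAUSE_DELIMITERS, text) if e.strip()]

print(split_into_elements("你是🐷,太垃圾了,滚开!"))   # ['你是🐷', '太垃圾了', '滚开']
print(split_into_elements("You are a pig, get lost!"))  # ['You are a pig', 'get lost']
```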
+ # Main page
  st.title("🚨 Emoji Offensive Text Detector & Analysis Dashboard")

+ # Text input
  st.subheader("1. 输入与分类")
+ default_text = "你是🐷,太垃圾了,滚开!"
  text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)

  if st.button("🚦 Analyze Text"):
      with st.spinner("🔍 Processing..."):
          try:
+             analysis_results, radar_scores = classify_text_elements(text)
+
+             st.markdown("### ✨ Element-wise Classification")
+             for item in analysis_results:
+                 st.markdown(f"- 🔹 **Input:** {item['text']}")
+                 st.markdown(f"  - ✨ **Translated:** {item['translated']}")
+                 st.markdown(f"  - ❗ **Label:** {item['label']} → **{item['category']}** ({item['score']:.2%})")
+                 st.markdown(f"  - 🔧 **Reasoning:** {item['reason']}")
+
+             st.success("✅ Analysis complete!")
          except Exception as e:
              st.error(f"❌ An error occurred:\n{e}")
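
`radar_scores` returned here lives only for the script run triggered by the button click; Streamlit re-executes app.py on every interaction, so on a later rerun the dashboard at the bottom would reference an undefined `radar_scores`. A hedged sketch that persists the accumulated scores in `st.session_state` (the `remember_radar_scores` helper is illustrative):

```python
from collections import defaultdict
import streamlit as st

# Hedged sketch: keep per-category scores across Streamlit reruns so the
# dashboard can read them even when the analysis ran on an earlier run.
if "radar_scores" not in st.session_state:
    st.session_state.radar_scores = defaultdict(float)

def remember_radar_scores(radar_scores: dict) -> None:
    for category, score in radar_scores.items():
        st.session_state.radar_scores[category] += score

# In the button handler (illustrative):
# analysis_results, radar_scores = classify_text_elements(text)
# remember_radar_scores(radar_scores)
```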

+ # Image upload & OCR
  st.markdown("---")
  st.subheader("2. 图片 OCR & 分类")
  uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg","jpeg","png"])

      if ocr_text:
          st.markdown("**Extracted Text:**")
          st.code(ocr_text)
+         analysis_results, radar_scores = classify_text_elements(ocr_text)
+         for item in analysis_results:
+             st.markdown(f"- 🔹 **Input:** {item['text']}")
+             st.markdown(f"  - ✨ **Translated:** {item['translated']}")
+             st.markdown(f"  - ❗ **Label:** {item['label']} → **{item['category']}** ({item['score']:.2%})")
+             st.markdown(f"  - 🔧 **Reasoning:** {item['reason']}")
      else:
          st.info("⚠️ No text detected in the image.")

  st.markdown("---")
  st.subheader("3. Violation Analysis Dashboard")
  if st.session_state.history:
      df = pd.DataFrame(st.session_state.history)
+
+     st.markdown("### 🧾 Offense History Summary")
      for item in st.session_state.history:
+         st.markdown(f"- **Input:** {item['text']}")
+         st.markdown(f"  - 🔠 Translated: {item['translated']}")
+         st.markdown(f"  - 🏷️ Label: {item['label']} {item['category']}, Score: {item['score']:.2%}")
+
+     # Accumulate radar scores
+     category_list = ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"]
+     radar_data = {
+         "Category": category_list,
+         "Score": [min(st.session_state.history.count(c)/len(st.session_state.history), 1.0)
+                   if c in radar_scores else 0.0 for c in category_list]
+     }
+     radar_df = pd.DataFrame(radar_data)
+     radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
+     radar_fig.update_traces(line_color='black')
      st.plotly_chart(radar_fig)
  else:
      st.info("⚠️ No classification data available yet.")