aeresd committed · verified
Commit 40382a6 · Parent(s): 896a453

Update app.py

Files changed (1):
  1. app.py +76 -64
app.py CHANGED
@@ -6,7 +6,7 @@ import pytesseract
 import pandas as pd
 import plotly.express as px
 
-# ✅ Step 1: Emoji translation model (your own fine-tuned model)
+# ✅ Step 1: Emoji translation model
 emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
 emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
 emoji_model = AutoModelForCausalLM.from_pretrained(
@@ -16,7 +16,7 @@ emoji_model = AutoModelForCausalLM.from_pretrained(
 ).to("cuda" if torch.cuda.is_available() else "cpu")
 emoji_model.eval()
 
-# ✅ Step 2: Selectable offensive-text classification models
+# ✅ Step 2: Offensive-text classification models
 model_options = {
     "Toxic-BERT": "unitary/toxic-bert",
     "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
@@ -26,18 +26,23 @@ model_options = {
 # ✅ Page configuration
 st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")
 
-# ✅ Sidebar: model selection
+# ✅ Sidebar configuration
 with st.sidebar:
     st.header("🧠 Configuration")
     selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
     selected_model_id = model_options[selected_model]
-    classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
+    classifier = pipeline(
+        "text-classification",
+        model=selected_model_id,
+        device=0 if torch.cuda.is_available() else -1,
+        return_all_scores=True
+    )
 
 # Initialize history
 if "history" not in st.session_state:
     st.session_state.history = []
 
-# Classification function
+# Classification function (optimized)
 def classify_emoji_text(text: str):
     prompt = f"输入:{text}\n输出:"
     input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
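Note on the reworked `pipeline(...)` call: recent transformers releases deprecate `return_all_scores=True` in favor of `top_k=None`, and for a single string input the pipeline wraps the per-label scores in an outer list, which is why the call site below indexes with `[0]`. A minimal sketch, assuming a recent transformers version:

```python
# Minimal sketch (not part of this commit): top_k=None is the current
# replacement for the deprecated return_all_scores=True.
from transformers import pipeline

clf = pipeline("text-classification", model="unitary/toxic-bert", top_k=None)

result = clf("example text")
# The outer nesting differs across versions and flags; normalize it.
scores = result[0] if isinstance(result[0], list) else result
for entry in scores:
    print(entry["label"], round(entry["score"], 3))
```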
@@ -46,27 +51,63 @@ def classify_emoji_text(text: str):
     decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
     translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
 
-    result = classifier(translated_text)[0]
-    label = result["label"]
-    score = result["score"]
-    reasoning = (
-        f"The sentence was flagged as '{label}' due to potentially offensive phrases. "
-        "Consider replacing emotionally charged, ambiguous, or abusive terms."
-    )
-
+    # Get scores for all labels (the pipeline wraps a single input in an outer list)
+    all_results = classifier(translated_text)[0]
+
+    # Radar chart categories
+    radar_categories = ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"]
+    radar_scores = {category: 0.0 for category in radar_categories}
+
+    # Model-specific label mappings
+    model_mappings = {
+        "Toxic-BERT": {
+            "toxic": "Vulgarity",
+            "severe_toxic": "Abuse",
+            "obscene": "Vulgarity",
+            "threat": "Hate Speech",
+            "insult": "Insult",
+            "identity_hate": "Discrimination"
+        },
+        "Roberta Offensive": {
+            "offensive": ["Insult", "Abuse", "Vulgarity"]
+        }
+    }
+
+    # Derive radar scores from the model output
+    for result in all_results:
+        label = result["label"]
+        score = result["score"]
+
+        if selected_model == "Toxic-BERT":
+            mapped_category = model_mappings["Toxic-BERT"].get(label)
+            if mapped_category and score > radar_scores[mapped_category]:
+                radar_scores[mapped_category] = score
+        elif selected_model == "Roberta Offensive" and label == "offensive":
+            for category in model_mappings["Roberta Offensive"]["offensive"]:
+                if score > radar_scores[category]:
+                    radar_scores[category] = score
+
+    # Primary classification result
+    primary_result = max(all_results, key=lambda x: x["score"])
+    label = primary_result["label"]
+    score = primary_result["score"]
+    reasoning = f"The sentence was flagged as '{label}' due to potentially offensive phrases. Consider replacing emotionally charged, ambiguous, or abusive terms."
+
+    # Store in history
     st.session_state.history.append({
-        "text": text,
+        "text": text,
         "translated": translated_text,
         "label": label,
         "score": score,
-        "reason": reasoning
+        "reason": reasoning,
+        "radar_scores": radar_scores
    })
-    return translated_text, label, score, reasoning
+    return translated_text, label, score, reasoning, radar_scores
 
-# Main page: input and analysis together
+# Main interface
 st.title("🚨 Emoji Offensive Text Detector & Analysis Dashboard")
 
-# Text input
+# Text input analysis
 st.subheader("1. Input & Classification")
 default_text = "你是🐷"
 text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
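To make the new radar logic concrete: the fold keeps, per category, the maximum score of any label mapped to it. A self-contained rerun with made-up Toxic-BERT scores (values are illustrative, not model output):

```python
# Fold hypothetical per-label scores into the five radar categories
# using the same max-per-category rule as classify_emoji_text.
dummy_results = [
    {"label": "toxic", "score": 0.91},
    {"label": "insult", "score": 0.84},
    {"label": "obscene", "score": 0.40},
    {"label": "identity_hate", "score": 0.05},
]
mapping = {
    "toxic": "Vulgarity", "severe_toxic": "Abuse", "obscene": "Vulgarity",
    "threat": "Hate Speech", "insult": "Insult", "identity_hate": "Discrimination",
}

radar = {c: 0.0 for c in ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"]}
for r in dummy_results:
    category = mapping.get(r["label"])
    if category and r["score"] > radar[category]:
        radar[category] = r["score"]

# "toxic" (0.91) beats "obscene" (0.40) for Vulgarity:
print(radar)  # {'Insult': 0.84, 'Abuse': 0.0, 'Discrimination': 0.05, 'Hate Speech': 0.0, 'Vulgarity': 0.91}
```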
@@ -74,7 +115,7 @@ text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
 if st.button("🚦 Analyze Text"):
     with st.spinner("🔍 Processing..."):
         try:
-            translated, label, score, reason = classify_emoji_text(text)
+            translated, label, score, reason, radar = classify_emoji_text(text)
             st.markdown("**Translated sentence:**")
             st.code(translated, language="text")
             st.markdown(f"**Prediction:** {label}")
@@ -84,7 +125,7 @@ if st.button("🚦 Analyze Text"):
         except Exception as e:
             st.error(f"❌ An error occurred:\n{e}")
 
-# Image upload & OCR
+# Image analysis
 st.markdown("---")
 st.subheader("2. Image OCR & Classification")
 uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg","jpeg","png"])
@@ -96,7 +137,7 @@ if uploaded_file:
     if ocr_text:
         st.markdown("**Extracted Text:**")
         st.code(ocr_text)
-        translated, label, score, reason = classify_emoji_text(ocr_text)
+        translated, label, score, reason, radar = classify_emoji_text(ocr_text)
         st.markdown("**Translated sentence:**")
         st.code(translated, language="text")
         st.markdown(f"**Prediction:** {label}")
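The OCR step itself falls outside the changed lines, so it is not shown here. A hypothetical reconstruction of how `ocr_text` is plausibly produced, assuming PIL plus pytesseract; the `lang` value is an assumption for mixed Chinese/English input:

```python
# Hypothetical sketch of the elided OCR step; not part of this commit.
from PIL import Image
import pytesseract

def extract_text(uploaded_file) -> str:
    image = Image.open(uploaded_file)  # uploaded_file from st.file_uploader
    # lang is an assumption; requires the chi_sim traineddata to be installed
    return pytesseract.image_to_string(image, lang="chi_sim+eng").strip()
```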
@@ -110,7 +151,7 @@ if uploaded_file:
 st.markdown("---")
 st.subheader("3. Violation Analysis Dashboard")
 if st.session_state.history:
-    # Display history
+    # History display
     df = pd.DataFrame(st.session_state.history)
     st.markdown("### 🧾 Offensive Terms & Suggestions")
     for item in st.session_state.history:
@@ -119,50 +160,21 @@ if st.session_state.history:
         st.markdown(f" - ❗ **Label:** {item['label']} with **{item['score']:.2%}** confidence")
         st.markdown(f" - 🔧 **Suggestion:** {item['reason']}")
 
-    # Radar chart
+    # Dynamically generated radar chart
+    latest_radar = st.session_state.history[-1]["radar_scores"]
     radar_df = pd.DataFrame({
-        "Category": ["Insult","Abuse","Discrimination","Hate Speech","Vulgarity"],
-        "Score": [0.7,0.4,0.3,0.5,0.6]
+        "Category": list(latest_radar.keys()),
+        "Score": list(latest_radar.values())
     })
-    radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
-    radar_fig.update_traces(line_color='black')
+    radar_fig = px.line_polar(
+        radar_df,
+        r='Score',
+        theta='Category',
+        line_close=True,
+        title="⚠️ Risk Radar by Category",
+        range_r=[0,1]
+    )
+    radar_fig.update_traces(fill='toself', line_color='red')
     st.plotly_chart(radar_fig)
-
-    # —— New: word-level offensiveness analysis —— #
-    st.markdown("### 🧬 Word-level Offensive Correlation")
-
-    # Take the latest translated text and split it into words on whitespace
-    last_translated_text = st.session_state.history[-1]["translated"]
-    words = last_translated_text.split()
-
-    # Classify each word and collect its score
-    word_scores = []
-    for word in words:
-        try:
-            res = classifier(word)[0]
-            word_scores.append({
-                "Word": word,
-                "Label": res["label"],
-                "Score": res["score"]
-            })
-        except Exception:
-            continue
-
-    if word_scores:
-        word_df = pd.DataFrame(word_scores)
-        word_df = word_df.sort_values(by="Score", ascending=False).reset_index(drop=True)
-
-        max_display = 5
-        # st.toggle requires Streamlit 1.22+; fall back to st.checkbox on older versions
-        show_more = st.toggle("Show more words", value=False)
-
-        display_df = word_df if show_more else word_df.head(max_display)
-        # Render as a borderless HTML table
-        st.markdown(
-            display_df.to_html(index=False, border=0),
-            unsafe_allow_html=True
-        )
-    else:
-        st.info("❕ No word-level analysis available.")
 else:
     st.info("⚠️ No classification data available yet.")
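This hunk swaps the hard-coded radar values for scores derived from the selected classifier, pins the radial axis with `range_r=[0, 1]` so runs are comparable, fills the polygon, and removes the word-level correlation table outright. A standalone rendering of the new figure with illustrative values:

```python
# Standalone sketch of the new radar figure; scores are illustrative.
import pandas as pd
import plotly.express as px

radar_df = pd.DataFrame({
    "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
    "Score": [0.84, 0.0, 0.05, 0.0, 0.91],
})
fig = px.line_polar(
    radar_df, r="Score", theta="Category", line_close=True,
    title="⚠️ Risk Radar by Category", range_r=[0, 1],
)
fig.update_traces(fill="toself", line_color="red")
fig.show()
```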
 