cordwainersmith committed on
Commit 277ab09 · 1 Parent(s): 01d7fe4

Add application file

Files changed (1)
  1. app.py +448 -0
app.py ADDED
@@ -0,0 +1,448 @@
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+ import time
+ import json
+ import pandas as pd
+ from datetime import datetime
+ import os
+ from typing import List, Dict, Tuple
+ import re
+
+ # Hugging Face access token. The original file references HF_TOKEN without
+ # defining it; it is assumed here to come from the environment (e.g. a Space
+ # secret). A public model also loads with token=None.
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ # Constants
+ MODELS = {
+     "GolemPII XLM-RoBERTa v1": "CordwainerSmith/GolemPII-xlm-roberta-v1",
+ }
+
+
+ ENTITY_COLORS = {
+     "PHONE_NUM": "#FF9999",
+     "ID_NUM": "#99FF99",
+     "CC_NUM": "#9999FF",
+     "BANK_ACCOUNT_NUM": "#FFFF99",
+     "FIRST_NAME": "#FF99FF",
+     "LAST_NAME": "#99FFFF",
+     "CITY": "#FFB366",
+     "STREET": "#B366FF",
+     "POSTAL_CODE": "#66FFB3",
+     "EMAIL": "#66B3FF",
+     "DATE": "#FFB3B3",
+     "CC_PROVIDER": "#B3FFB3",
+ }
+
+ EXAMPLE_SENTENCES = [
+     "שם מלא: תלמה אריאלי מספר תעודת זהות: 61453324-8 תאריך לידה: 15/09/1983 כתובת: ארלוזורוב 22 פתח תקווה מיקוד 2731711 אימייל: [email protected] טלפון: 054-8884771 בפגישה זו נדונו פתרונות טכנולוגיים חדשניים לשיפור תהליכי עבודה. המשתתף יתבקש להציג מצגת בנושא בפגישה הבאה אשר שילם ב 5326-1003-5299-5478 מסטרקארד עם הוראת קבע ל 11-77-352300",
+ ]
+
+ MODEL_DETAILS = {
+     "name": "GolemPII - Hebrew PII Detection Model CordwainerSmith/GolemPII-v7-full",
+     "description": "This on-premise PII model is designed to automatically identify and mask sensitive information (PII) within Hebrew text data. It has been trained to recognize a wide range of PII entities, including names, addresses, phone numbers, financial information, and more.",
+     "base_model": "microsoft/mdeberta-v3-base",
+     "training_data": "Custom Hebrew PII dataset (size not specified)",
+     "detected_pii_entities": [
+         "FIRST_NAME",
+         "LAST_NAME",
+         "STREET",
+         "CITY",
+         "PHONE_NUM",
+         "EMAIL",
+         "ID_NUM",
+         "BANK_ACCOUNT_NUM",
+         "CC_NUM",
+         "CC_PROVIDER",
+         "DATE",
+         "POSTAL_CODE",
+     ],
+     "training_details": {
+         "Training epochs": "5",
+         "Batch size": "32",
+         "Learning rate": "5e-5",
+         "Weight decay": "0.01",
+         "Training speed": "~2.19 it/s",
+         "Total training time": "2:08:26",
+     },
+ }
+
+
+ class PIIMaskingModel:
+     """Token-classification wrapper that detects and masks PII spans in Hebrew text."""
+
+     def __init__(self, model_name: str):
+         self.model_name = model_name
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
+         self.model = AutoModelForTokenClassification.from_pretrained(
+             model_name, token=HF_TOKEN
+         )
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model.eval()
+
+     def process_text(
+         self, text: str
+     ) -> Tuple[str, float, str, List[str], List[str], List[Dict]]:
+         """Run the model on `text` and return the masked text, processing time,
+         HTML-highlighted text, tokens, predicted labels, and span records."""
+         start_time = time.time()
+
+         tokenized_inputs = self.tokenizer(
+             text,
+             truncation=True,
+             padding=False,
+             return_tensors="pt",
+             return_offsets_mapping=True,
+             add_special_tokens=True,
+         )
+
+         input_ids = tokenized_inputs.input_ids.to(self.device)
+         attention_mask = tokenized_inputs.attention_mask.to(self.device)
+         offset_mapping = tokenized_inputs["offset_mapping"][0].tolist()
+
+         # Mark special tokens so they are skipped during span reconstruction
+         offset_mapping[0] = None  # <s> token
+         offset_mapping[-1] = None  # </s> token
+
+         with torch.no_grad():
+             outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+
+         predictions = outputs.logits.argmax(dim=-1).cpu().numpy()
+         predicted_labels = [
+             self.model.config.id2label[label_id] for label_id in predictions[0]
+         ]
+         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+
+         masked_text, colored_text, privacy_masks = self.mask_pii_in_sentence(
+             tokens, predicted_labels, text, offset_mapping
+         )
+         processing_time = time.time() - start_time
+
+         return (
+             masked_text,
+             processing_time,
+             colored_text,
+             tokens,
+             predicted_labels,
+             privacy_masks,
+         )
+
+     def _find_entity_span(
+         self,
+         i: int,
+         labels: List[str],
+         tokens: List[str],
+         offset_mapping: List[Tuple[int, int]],
+     ) -> Tuple[int, str, int]:
+         """Find the end index and entity type for a span starting at index i"""
+         current_entity = labels[i][2:]  # strip the "B-"/"I-" prefix
+         j = i + 1
+         last_valid_end = offset_mapping[i][1] if offset_mapping[i] else None
+
+         while j < len(tokens):
+             if offset_mapping[j] is None:
+                 j += 1
+                 continue
+
+             next_label = labels[j]
+
+             # Stop if we hit a new B- tag (except for non-spaced tokens)
+             if next_label.startswith("B-") and tokens[j].startswith("▁"):
+                 break
+
+             # Stop if we hit a different entity type in I- tags
+             if next_label.startswith("I-") and next_label[2:] != current_entity:
+                 break
+
+             # Continue if it's a continuation of the same entity
+             if next_label.startswith("I-") and next_label[2:] == current_entity:
+                 last_valid_end = offset_mapping[j][1]
+                 j += 1
+             # Continue if it's a non-spaced B- token (subword without a leading "▁")
+             elif next_label.startswith("B-") and not tokens[j].startswith("▁"):
+                 last_valid_end = offset_mapping[j][1]
+                 j += 1
+             else:
+                 break
+
+         return j, current_entity, last_valid_end
+
+     def mask_pii_in_sentence(
+         self,
+         tokens: List[str],
+         labels: List[str],
+         original_text: str,
+         offset_mapping: List[Tuple[int, int]],
+     ) -> Tuple[str, str, List[Dict]]:
+         privacy_masks = []
+         current_pos = 0
+         masked_text_parts = []
+         colored_text_parts = []
+
+         i = 0
+         while i < len(tokens):
+             if offset_mapping[i] is None:  # Skip special tokens
+                 i += 1
+                 continue
+
+             current_label = labels[i]
+
+             if current_label.startswith(("B-", "I-")):
+                 start_char = offset_mapping[i][0]
+
+                 # Find the complete entity span
+                 next_pos, entity_type, last_valid_end = self._find_entity_span(
+                     i, labels, tokens, offset_mapping
+                 )
+
+                 # Add any text before the entity
+                 if current_pos < start_char:
+                     text_before = original_text[current_pos:start_char]
+                     masked_text_parts.append(text_before)
+                     colored_text_parts.append(text_before)
+
+                 # Extract and mask the entity
+                 entity_value = original_text[start_char:last_valid_end]
+                 mask = self._get_mask_for_entity(entity_type)
+
+                 # Add to privacy masks
+                 privacy_masks.append(
+                     {
+                         "label": entity_type,
+                         "start": start_char,
+                         "end": last_valid_end,
+                         "value": entity_value,
+                         "label_index": len(privacy_masks) + 1,
+                     }
+                 )
+
+                 # Add masked text
+                 masked_text_parts.append(mask)
+
+                 # Add colored text
+                 color = ENTITY_COLORS.get(entity_type, "#CCCCCC")
+                 colored_text_parts.append(
+                     f'<span style="background-color: {color}; padding: 2px; border-radius: 3px;">{mask}</span>'
+                 )
+
+                 current_pos = last_valid_end
+                 i = next_pos
+             else:
+                 if offset_mapping[i] is not None:
+                     start_char = offset_mapping[i][0]
+                     end_char = offset_mapping[i][1]
+
+                     # Add any text for this token
+                     if current_pos < end_char:
+                         text_chunk = original_text[current_pos:end_char]
+                         masked_text_parts.append(text_chunk)
+                         colored_text_parts.append(text_chunk)
+                     current_pos = end_char
+                 i += 1
+
+         # Add any remaining text
+         if current_pos < len(original_text):
+             remaining_text = original_text[current_pos:]
+             masked_text_parts.append(remaining_text)
+             colored_text_parts.append(remaining_text)
+
+         return ("".join(masked_text_parts), "".join(colored_text_parts), privacy_masks)
+
+     def _get_mask_for_entity(self, entity_type: str) -> str:
+         """Get the mask text for a given entity type"""
+         return {
+             "PHONE_NUM": "[טלפון]",
+             "ID_NUM": "[ת.ז]",
+             "CC_NUM": "[כרטיס אשראי]",
+             "BANK_ACCOUNT_NUM": "[חשבון בנק]",
+             "FIRST_NAME": "[שם פרטי]",
+             "LAST_NAME": "[שם משפחה]",
+             "CITY": "[עיר]",
+             "STREET": "[רחוב]",
+             "POSTAL_CODE": "[מיקוד]",
+             "EMAIL": "[אימייל]",
+             "DATE": "[תאריך]",
+             "CC_PROVIDER": "[ספק כרטיס אשראי]",
+             "BANK": "[בנק]",
+         }.get(entity_type, f"[{entity_type}]")
+
+
+ def save_results_to_file(results: Dict):
+     """
+     Save processing results to a JSON file
+     """
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     filename = f"pii_masking_results_{timestamp}.json"
+
+     with open(filename, "w", encoding="utf-8") as f:
+         json.dump(results, f, ensure_ascii=False, indent=2)
+
+     return filename
+
+
+ def main():
+     st.set_page_config(layout="wide")
+     st.title("🗿 GolemPII: Hebrew PII Masking Application 🗿")
+
+     # Add CSS styles
+     st.markdown(
+         """
+         <style>
+         .rtl { direction: rtl; text-align: right; }
+         .entity-legend { padding: 5px; margin: 2px; border-radius: 3px; display: inline-block; }
+         .masked-text {
+             direction: rtl;
+             text-align: right;
+             line-height: 2;
+             padding: 10px;
+             background-color: #f6f8fa;
+             border-radius: 5px;
+             color: black;
+             white-space: pre-wrap;
+         }
+         /* Red headers for sections */
+         .main h3 {
+             color: #d73a49;
+             margin-bottom: 10px;
+         }
+         /* Styles for the model details sidebar */
+         .model-details-sidebar h2 {
+             margin-top: 0;
+         }
+         .model-details-sidebar table {
+             width: 100%;
+             border-collapse: collapse;
+         }
+         .model-details-sidebar td, .model-details-sidebar th {
+             padding: 8px;
+             border: 1px solid #ddd;
+             text-align: left;
+         }
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
+     # Sidebar configuration
+     st.sidebar.header("Configuration")
+     selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
+     show_json = st.sidebar.checkbox("Show JSON Output", value=True)
+     run_all_models = st.sidebar.checkbox("Run All Models")
+
+     # Display Model Details in Sidebar
+     st.sidebar.markdown(
+         f"""
+         <div class="model-details-sidebar">
+             <h2>Model Details: {MODEL_DETAILS['name']}</h2>
+             <p>{MODEL_DETAILS['description']}</p>
+             <table>
+                 <tr><td>Base Model:</td><td>{MODEL_DETAILS['base_model']}</td></tr>
+                 <tr><td>Training Data:</td><td>{MODEL_DETAILS['training_data']}</td></tr>
+             </table>
+             <h3>Detected PII Entities</h3>
+             <ul>
+                 {" ".join([f'<li><span class="entity-badge" style="background-color: {ENTITY_COLORS.get(entity, "#CCCCCC")}; padding: 3px 5px; border-radius: 3px; margin-right: 5px;">{entity}</span></li>' for entity in MODEL_DETAILS['detected_pii_entities']])}
+             </ul>
+         </div>
+         """,
+         unsafe_allow_html=True,
+     )
+
+     # Text input
+     text_input = st.text_area(
+         "Enter text to mask (separate multiple texts with commas):",
+         value="\n".join(EXAMPLE_SENTENCES),
+         height=200,
+     )
+
+     # Process button
+     if st.button("Process Text"):
+         texts = [text.strip() for text in text_input.split(",") if text.strip()]
+
+         if run_all_models:
+             all_results = {}
+             progress_bar = st.progress(0)
+
+             for idx, (model_name, model_path) in enumerate(MODELS.items()):
+                 st.subheader(f"Results for {model_name}")
+                 model = PIIMaskingModel(model_path)
+                 model_results = {}
+
+                 for text_idx, text in enumerate(texts):
+                     (
+                         masked_text,
+                         processing_time,
+                         colored_text,
+                         tokens,
+                         predicted_labels,
+                         privacy_masks,
+                     ) = model.process_text(text)
+                     model_results[f"text_{text_idx+1}"] = {
+                         "original": text,
+                         "masked": masked_text,
+                         "processing_time": processing_time,
+                         "privacy_mask": privacy_masks,
+                         "span_labels": [
+                             [m["start"], m["end"], m["label"]] for m in privacy_masks
+                         ],
+                     }
+
+                 all_results[model_name] = model_results
+                 progress_bar.progress((idx + 1) / len(MODELS))
+
+             # Save and display results
+             filename = save_results_to_file(all_results)
+             st.success(f"Results saved to {filename}")
+
+             # Show comparison table
+             comparison_data = []
+             for model_name, results in all_results.items():
+                 avg_time = sum(
+                     text_data["processing_time"] for text_data in results.values()
+                 ) / len(results)
+                 comparison_data.append(
+                     {"Model": model_name, "Avg Processing Time": f"{avg_time:.3f}s"}
+                 )
+
+             st.subheader("Model Comparison")
+             st.table(pd.DataFrame(comparison_data))
+
+         else:
+             # Process with single selected model
+             model = PIIMaskingModel(MODELS[selected_model])
+
+             for text in texts:
+                 st.markdown("### Original Text", unsafe_allow_html=True)
+                 st.markdown(f'<div class="rtl">{text}</div>', unsafe_allow_html=True)
+
+                 (
+                     masked_text,
+                     processing_time,
+                     colored_text,
+                     tokens,
+                     predicted_labels,
+                     privacy_masks,
+                 ) = model.process_text(text)
+
+                 st.markdown("### Masked Text", unsafe_allow_html=True)
+                 st.markdown(
+                     f'<div class="masked-text">{colored_text}</div>',
+                     unsafe_allow_html=True,
+                 )
+
+                 st.markdown(f"Processing Time: {processing_time:.3f} seconds")
+
+                 if show_json:
+                     st.json(
+                         {
+                             "original": text,
+                             "masked": masked_text,
+                             "processing_time": processing_time,
+                             "tokens": tokens,
+                             "token_classes": predicted_labels,
+                             "privacy_mask": privacy_masks,
+                             "span_labels": [
+                                 [m["start"], m["end"], m["label"]]
+                                 for m in privacy_masks
+                             ],
+                         }
+                     )
+
+                 st.markdown("---")
+
+
+ if __name__ == "__main__":
+     main()
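
Not part of the commit: a minimal sketch of driving the new PIIMaskingModel class directly from Python, for anyone who wants to try the masking logic without the Streamlit UI. It assumes the dependencies imported in app.py are installed, the model repository is reachable, and HF_TOKEN (if required) is exported in the environment; the sample string is shortened from EXAMPLE_SENTENCES. The app itself would be launched with `streamlit run app.py`.

```python
# Sketch only: exercise PIIMaskingModel outside Streamlit. Assumes app.py is on
# the import path and any required HF_TOKEN is set in the environment.
from app import MODELS, PIIMaskingModel

model = PIIMaskingModel(MODELS["GolemPII XLM-RoBERTa v1"])

# Shortened sample drawn from EXAMPLE_SENTENCES in app.py
sample = "שם מלא: תלמה אריאלי טלפון: 054-8884771 כתובת: ארלוזורוב 22 פתח תקווה"

masked_text, seconds, colored_html, tokens, labels, spans = model.process_text(sample)

print(masked_text)        # PII replaced with Hebrew placeholder tags such as [טלפון]
print(f"{seconds:.3f}s")  # wall-clock processing time
print(spans)              # list of {label, start, end, value, label_index} dicts
```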