import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch  # required for gradient-free inference in process_text
import time
import json
import pandas as pd
from datetime import datetime
import os
from typing import List, Dict, Tuple
import re
# Constants
MODELS = {
"GolemPII XLM-RoBERTa v1": "CordwainerSmith/GolemPII-xlm-roberta-v1",
}
ENTITY_COLORS = {
"PHONE_NUM": "#FF9999",
"ID_NUM": "#99FF99",
"CC_NUM": "#9999FF",
"BANK_ACCOUNT_NUM": "#FFFF99",
"FIRST_NAME": "#FF99FF",
"LAST_NAME": "#99FFFF",
"CITY": "#FFB366",
"STREET": "#B366FF",
"POSTAL_CODE": "#66FFB3",
"EMAIL": "#66B3FF",
"DATE": "#FFB3B3",
"CC_PROVIDER": "#B3FFB3",
}
EXAMPLE_SENTENCES = [
"שם מלא: תלמה אריאלי מספר תעודת זהות: 61453324-8 תאריך לידה: 15/09/1983 כתובת: ארלוזורוב 22 פתח תקווה מיקוד 2731711 אימייל: mihailbenavi@ebox.co.il טלפון: 054-8884771 בפגישה זו נדונו פתרונות טכנולוגיים חדשניים לשיפור תהליכי עבודה. המשתתף יתבקש להציג מצגת בנושא בפגישה הבאה אשר שילם ב 5326-1003-5299-5478 מסטרקארד עם הוראת קבע ל 11-77-352300",
]
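# Rough English gloss of the example above, for readers who don't read Hebrew:
# "Full name: Talma Arieli. ID number: 61453324-8. Date of birth: 15/09/1983.
#  Address: Arlozorov 22, Petah Tikva, postal code 2731711.
#  Email: mihailbenavi@ebox.co.il. Phone: 054-8884771. In this meeting,
#  innovative technological solutions for improving work processes were
#  discussed. The participant will be asked to give a presentation on the topic
#  at the next meeting; paid with Mastercard 5326-1003-5299-5478, with a
#  standing order to 11-77-352300."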
MODEL_DETAILS = {
"name": "GolemPII - Hebrew PII Detection Model CordwainerSmith/GolemPII-v7-full",
"description": "This on-premise PII model is designed to automatically identify and mask sensitive information (PII) within Hebrew text data. It has been trained to recognize a wide range of PII entities, including names, addresses, phone numbers, financial information, and more.",
"base_model": "microsoft/mdeberta-v3-base",
"training_data": "Custom Hebrew PII dataset (size not specified)",
"detected_pii_entities": [
"FIRST_NAME",
"LAST_NAME",
"STREET",
"CITY",
"PHONE_NUM",
"EMAIL",
"ID_NUM",
"BANK_ACCOUNT_NUM",
"CC_NUM",
"CC_PROVIDER",
"DATE",
"POSTAL_CODE",
],
"training_details": {
"Training epochs": "5",
"Batch size": "32",
"Learning rate": "5e-5",
"Weight decay": "0.01",
"Training speed": "~2.19 it/s",
"Total training time": "2:08:26",
},
}
class PIIMaskingModel:
    """Wraps a Hugging Face token-classification model for Hebrew PII masking."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        hf_token = st.secrets["HF_TOKEN"]  # Hugging Face token from Streamlit secrets
        # `token` replaces the deprecated `use_auth_token` argument
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, token=hf_token
        )
def process_text(
self, text: str
) -> Tuple[str, float, str, List[str], List[str], List[Dict]]:
start_time = time.time()
        tokenized_inputs = self.tokenizer(
            text,
            truncation=True,
            padding=False,
            return_tensors="pt",  # PyTorch tensors; the model cannot consume NumPy arrays
            return_offsets_mapping=True,
            add_special_tokens=True,
        )
        input_ids = tokenized_inputs.input_ids
        attention_mask = tokenized_inputs.attention_mask
        offset_mapping = tokenized_inputs["offset_mapping"][0].tolist()
        # Mark the special tokens so downstream masking skips them
        offset_mapping[0] = None  # <s> token
        offset_mapping[-1] = None  # </s> token
        # Disable gradient tracking for inference
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = outputs.logits.argmax(dim=-1)
        predicted_labels = [
            self.model.config.id2label[label_id]
            for label_id in predictions[0].tolist()
        ]
tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
masked_text, colored_text, privacy_masks = self.mask_pii_in_sentence(
tokens, predicted_labels, text, offset_mapping
)
processing_time = time.time() - start_time
return (
masked_text,
processing_time,
colored_text,
tokens,
predicted_labels,
privacy_masks,
)
def _find_entity_span(
self,
i: int,
labels: List[str],
tokens: List[str],
offset_mapping: List[Tuple[int, int]],
) -> Tuple[int, str, int]:
"""Find the end index and entity type for a span starting at index i"""
        current_entity = labels[i][2:]  # entity type with the "B-"/"I-" prefix stripped
j = i + 1
last_valid_end = offset_mapping[i][1] if offset_mapping[i] else None
while j < len(tokens):
if offset_mapping[j] is None:
j += 1
continue
next_label = labels[j]
            # Stop if we hit a new B- tag that starts a new word ("▁" is the
            # SentencePiece word-boundary marker, not a plain space)
            if next_label.startswith("B-") and tokens[j].startswith("▁"):
                break
# Stop if we hit a different entity type in I- tags
if next_label.startswith("I-") and next_label[2:] != current_entity:
break
# Continue if it's a continuation of the same entity
if next_label.startswith("I-") and next_label[2:] == current_entity:
last_valid_end = offset_mapping[j][1]
j += 1
            # Continue if it's a B- token glued to the previous one (no "▁" prefix)
            elif next_label.startswith("B-") and not tokens[j].startswith("▁"):
                last_valid_end = offset_mapping[j][1]
                j += 1
else:
break
return j, current_entity, last_valid_end
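    # Illustration of the span-merging logic above (tokenization is hypothetical):
    #   tokens = ["▁תל", "מה", "▁אריא", "לי"]
    #   labels = ["B-FIRST_NAME", "I-FIRST_NAME", "B-LAST_NAME", "I-LAST_NAME"]
    # Starting at i=0, the scan consumes the I-FIRST_NAME continuation at index 1,
    # then stops at index 2 because "▁אריא" opens a new word with a B- tag, so the
    # FIRST_NAME span ends at token 1's character offset and j=2 is returned.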
def mask_pii_in_sentence(
self,
tokens: List[str],
labels: List[str],
original_text: str,
offset_mapping: List[Tuple[int, int]],
) -> Tuple[str, str, List[Dict]]:
privacy_masks = []
current_pos = 0
masked_text_parts = []
colored_text_parts = []
i = 0
while i < len(tokens):
if offset_mapping[i] is None: # Skip special tokens
i += 1
continue
current_label = labels[i]
if current_label.startswith(("B-", "I-")):
start_char = offset_mapping[i][0]
# Find the complete entity span
next_pos, entity_type, last_valid_end = self._find_entity_span(
i, labels, tokens, offset_mapping
)
# Add any text before the entity
if current_pos < start_char:
text_before = original_text[current_pos:start_char]
masked_text_parts.append(text_before)
colored_text_parts.append(text_before)
# Extract and mask the entity
entity_value = original_text[start_char:last_valid_end]
mask = self._get_mask_for_entity(entity_type)
# Add to privacy masks
privacy_masks.append(
{
"label": entity_type,
"start": start_char,
"end": last_valid_end,
"value": entity_value,
"label_index": len(privacy_masks) + 1,
}
)
# Add masked text
masked_text_parts.append(mask)
# Add colored text
                # Wrap the mask in a colored <span>, rendered later via
                # st.markdown(..., unsafe_allow_html=True); exact styling is approximate
                color = ENTITY_COLORS.get(entity_type, "#CCCCCC")
                colored_text_parts.append(
                    f'<span style="background-color: {color}; padding: 1px 4px; '
                    f'border-radius: 3px;">{mask}</span>'
                )
current_pos = last_valid_end
i = next_pos
            else:
                # Non-entity token: copy the original text through unchanged
                end_char = offset_mapping[i][1]
                if current_pos < end_char:
                    text_chunk = original_text[current_pos:end_char]
                    masked_text_parts.append(text_chunk)
                    colored_text_parts.append(text_chunk)
                current_pos = end_char
                i += 1
# Add any remaining text
if current_pos < len(original_text):
remaining_text = original_text[current_pos:]
masked_text_parts.append(remaining_text)
colored_text_parts.append(remaining_text)
return ("".join(masked_text_parts), "".join(colored_text_parts), privacy_masks)
def _get_mask_for_entity(self, entity_type: str) -> str:
"""Get the mask text for a given entity type"""
return {
"PHONE_NUM": "[טלפון]",
"ID_NUM": "[ת.ז]",
"CC_NUM": "[כרטיס אשראי]",
"BANK_ACCOUNT_NUM": "[חשבון בנק]",
"FIRST_NAME": "[שם פרטי]",
"LAST_NAME": "[שם משפחה]",
"CITY": "[עיר]",
"STREET": "[רחוב]",
"POSTAL_CODE": "[מיקוד]",
"EMAIL": "[אימייל]",
"DATE": "[תאריך]",
"CC_PROVIDER": "[ספק כרטיס אשראי]",
"BANK": "[בנק]",
}.get(entity_type, f"[{entity_type}]")
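# Minimal usage sketch (illustrative; assumes HF_TOKEN is set in
# .streamlit/secrets.toml, since __init__ reads it via st.secrets):
#
#   model = PIIMaskingModel(MODELS["GolemPII XLM-RoBERTa v1"])
#   masked, elapsed, colored, tokens, labels, masks = model.process_text(
#       EXAMPLE_SENTENCES[0]
#   )
#   print(masked)  # text with PII replaced by Hebrew placeholders
#   print(f"{elapsed:.2f}s, {len(masks)} entities found")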
def save_results_to_file(results: Dict):
"""
Save processing results to a JSON file
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"pii_masking_results_{timestamp}.json"
with open(filename, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
return filename
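# The saved JSON mirrors whatever dict the caller passes in; a typical payload
# built from process_text() output might look like (keys are illustrative):
#   {
#     "original_text": "...",
#     "masked_text": "...",
#     "privacy_masks": [{"label": "PHONE_NUM", "start": 7, "end": 18, ...}]
#   }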
def main():
st.set_page_config(layout="wide")
st.title("🗿 GolemPII: Hebrew PII Masking Application 🗿")
    # Inject custom CSS; this is a minimal RTL-friendly placeholder, as the
    # exact original stylesheet is assumed rather than specified
    st.markdown(
        """
        <style>
        .rtl-text { direction: rtl; text-align: right; }
        </style>
        """,
        unsafe_allow_html=True,
    )
# Sidebar configuration
st.sidebar.header("Configuration")
selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
show_json = st.sidebar.checkbox("Show JSON Output", value=True)
run_all_models = st.sidebar.checkbox("Run All Models")
# Display Model Details in Sidebar
st.sidebar.markdown(
f"""