import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch  # needed for inference with the PyTorch model below
import time
import json
import pandas as pd
from datetime import datetime
import os
from typing import List, Dict, Tuple
import re

# Constants
MODELS = {
    "GolemPII XLM-RoBERTa v1": "CordwainerSmith/GolemPII-xlm-roberta-v1",
}

ENTITY_COLORS = {
    "PHONE_NUM": "#FF9999",
    "ID_NUM": "#99FF99",
    "CC_NUM": "#9999FF",
    "BANK_ACCOUNT_NUM": "#FFFF99",
    "FIRST_NAME": "#FF99FF",
    "LAST_NAME": "#99FFFF",
    "CITY": "#FFB366",
    "STREET": "#B366FF",
    "POSTAL_CODE": "#66FFB3",
    "EMAIL": "#66B3FF",
    "DATE": "#FFB3B3",
    "CC_PROVIDER": "#B3FFB3",
}

EXAMPLE_SENTENCES = [
    # Hebrew demo text. English gloss: "Full name: Talma Arieli. ID number:
    # 61453324-8. Date of birth: 15/09/1983. Address: Arlozorov 22, Petah
    # Tikva, postal code 2731711. Email: [email protected]. Phone: 054-8884771.
    # This meeting discussed innovative technological solutions for improving
    # work processes. The participant will be asked to present on the topic at
    # the next meeting, having paid with Mastercard 5326-1003-5299-5478 via a
    # standing order to 11-77-352300."
    "שם מלא: תלמה אריאלי מספר תעודת זהות: 61453324-8 תאריך לידה: 15/09/1983 כתובת: ארלוזורוב 22 פתח תקווה מיקוד 2731711 אימייל: [email protected] טלפון: 054-8884771 בפגישה זו נדונו פתרונות טכנולוגיים חדשניים לשיפור תהליכי עבודה. המשתתף יתבקש להציג מצגת בנושא בפגישה הבאה אשר שילם ב 5326-1003-5299-5478 מסטרקארד עם הוראת קבע ל 11-77-352300",
]

MODEL_DETAILS = {
    "name": "GolemPII - Hebrew PII Detection Model CordwainerSmith/GolemPII-v7-full",
    "description": "This on-premise PII model is designed to automatically identify and mask sensitive information (PII) within Hebrew text data. It has been trained to recognize a wide range of PII entities, including names, addresses, phone numbers, financial information, and more.",
    "base_model": "microsoft/mdeberta-v3-base",
    "training_data": "Custom Hebrew PII dataset (size not specified)",
    "detected_pii_entities": [
        "FIRST_NAME",
        "LAST_NAME",
        "STREET",
        "CITY",
        "PHONE_NUM",
        "EMAIL",
        "ID_NUM",
        "BANK_ACCOUNT_NUM",
        "CC_NUM",
        "CC_PROVIDER",
        "DATE",
        "POSTAL_CODE",
    ],
    "training_details": {
        "Training epochs": "5",
        "Batch size": "32",
        "Learning rate": "5e-5",
        "Weight decay": "0.01",
        "Training speed": "~2.19 it/s",
        "Total training time": "2:08:26",
    },
}


class PIIMaskingModel:
    def __init__(self, model_name: str):
        self.model_name = model_name
        hf_token = st.secrets["HF_TOKEN"]  # Retrieve the token from Streamlit secrets
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, token=hf_token  # `token` replaces the deprecated `use_auth_token`
        )
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, token=hf_token
        )
        self.model.eval()  # inference only: disable dropout

    def process_text(
        self, text: str
    ) -> Tuple[str, float, str, List[str], List[str], List[Dict]]:
        """Tokenize `text`, run token classification, and mask detected PII.

        Returns (masked_text, processing_time, colored_html, tokens,
        predicted_labels, privacy_masks)."""
        start_time = time.time()
        tokenized_inputs = self.tokenizer(
            text,
            truncation=True,
            padding=False,
            return_tensors="pt",  # PyTorch tensors, since the model is a PyTorch model
            return_offsets_mapping=True,
            add_special_tokens=True,
        )
        input_ids = tokenized_inputs.input_ids
        attention_mask = tokenized_inputs.attention_mask
        offset_mapping = tokenized_inputs["offset_mapping"][0].tolist()

        # Mark special tokens so they are skipped during masking
        offset_mapping[0] = None  # <s> token
        offset_mapping[-1] = None  # </s> token

        # Inference only: disable gradient tracking to save time and memory
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = outputs.logits.argmax(dim=-1)

        predicted_labels = [
            self.model.config.id2label[label_id]
            for label_id in predictions[0].tolist()
        ]
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

        masked_text, colored_text, privacy_masks = self.mask_pii_in_sentence(
            tokens, predicted_labels, text, offset_mapping
        )
        processing_time = time.time() - start_time

        return (
            masked_text,
            processing_time,
            colored_text,
            tokens,
            predicted_labels,
            privacy_masks,
        )

    def _find_entity_span(
        self,
        i: int,
        labels: List[str],
        tokens: List[str],
        offset_mapping: List[Tuple[int, int]],
    ) -> Tuple[int, str, int]:
        """Find the end index, entity type, and last character offset for an
        entity span starting at token index i."""
        current_entity = labels[i][2:]  # strip the "B-"/"I-" prefix
        j = i + 1
        last_valid_end = offset_mapping[i][1] if offset_mapping[i] else None
        # XLM-RoBERTa uses SentencePiece, where "▁" (not a space) marks the
        # start of a new word
        word_start = "▁"
        while j < len(tokens):
            if offset_mapping[j] is None:
                j += 1
                continue
            next_label = labels[j]
            # Stop if we hit a new B- tag that starts a new word
            if next_label.startswith("B-") and tokens[j].startswith(word_start):
                break
            # Stop if we hit a different entity type in I- tags
            if next_label.startswith("I-") and next_label[2:] != current_entity:
                break
            # Continue if it's a continuation of the same entity
            if next_label.startswith("I-") and next_label[2:] == current_entity:
                last_valid_end = offset_mapping[j][1]
                j += 1
            # Continue if it's a B- token glued to the previous word
            elif next_label.startswith("B-") and not tokens[j].startswith(word_start):
                last_valid_end = offset_mapping[j][1]
                j += 1
            else:
                break
        return j, current_entity, last_valid_end

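    # Illustrative walk-through (hypothetical tokens/labels, not model output):
    # for
    #   tokens = ["▁054", "-", "888", "4771"]
    #   labels = ["B-PHONE_NUM", "I-PHONE_NUM", "I-PHONE_NUM", "I-PHONE_NUM"]
    # _find_entity_span(0, ...) follows the I- continuations and returns
    # (4, "PHONE_NUM", <end offset of "4771">), so the whole phone number is
    # masked as a single span rather than token by token.
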
    def mask_pii_in_sentence(
        self,
        tokens: List[str],
        labels: List[str],
        original_text: str,
        offset_mapping: List[Tuple[int, int]],
    ) -> Tuple[str, str, List[Dict]]:
        """Replace detected PII spans in `original_text` with mask tokens,
        returning the masked text, an HTML-highlighted version, and the list
        of span records."""
        privacy_masks = []
        current_pos = 0
        masked_text_parts = []
        colored_text_parts = []

        i = 0
        while i < len(tokens):
            if offset_mapping[i] is None:  # Skip special tokens
                i += 1
                continue

            current_label = labels[i]
            if current_label.startswith(("B-", "I-")):
                start_char = offset_mapping[i][0]
                # Find the complete entity span
                next_pos, entity_type, last_valid_end = self._find_entity_span(
                    i, labels, tokens, offset_mapping
                )
                # Add any text before the entity
                if current_pos < start_char:
                    text_before = original_text[current_pos:start_char]
                    masked_text_parts.append(text_before)
                    colored_text_parts.append(text_before)

                # Extract and mask the entity
                entity_value = original_text[start_char:last_valid_end]
                mask = self._get_mask_for_entity(entity_type)

                # Record the span
                privacy_masks.append(
                    {
                        "label": entity_type,
                        "start": start_char,
                        "end": last_valid_end,
                        "value": entity_value,
                        "label_index": len(privacy_masks) + 1,
                    }
                )

                # Add masked text
                masked_text_parts.append(mask)

                # Add colored text
                color = ENTITY_COLORS.get(entity_type, "#CCCCCC")
                colored_text_parts.append(
                    f'<span style="background-color: {color}; padding: 2px; border-radius: 3px;">{mask}</span>'
                )

                current_pos = last_valid_end
                i = next_pos
            else:
                # Non-entity token: copy the underlying text through unchanged
                # (offset_mapping[i] is guaranteed non-None here)
                end_char = offset_mapping[i][1]
                if current_pos < end_char:
                    text_chunk = original_text[current_pos:end_char]
                    masked_text_parts.append(text_chunk)
                    colored_text_parts.append(text_chunk)
                    current_pos = end_char
                i += 1

        # Add any remaining text
        if current_pos < len(original_text):
            remaining_text = original_text[current_pos:]
            masked_text_parts.append(remaining_text)
            colored_text_parts.append(remaining_text)

        return ("".join(masked_text_parts), "".join(colored_text_parts), privacy_masks)

    def _get_mask_for_entity(self, entity_type: str) -> str:
        """Get the (Hebrew) mask text for a given entity type."""
        return {
            "PHONE_NUM": "[טלפון]",  # phone
            "ID_NUM": "[ת.ז]",  # ID number
            "CC_NUM": "[כרטיס אשראי]",  # credit card
            "BANK_ACCOUNT_NUM": "[חשבון בנק]",  # bank account
            "FIRST_NAME": "[שם פרטי]",  # first name
            "LAST_NAME": "[שם משפחה]",  # last name
            "CITY": "[עיר]",  # city
            "STREET": "[רחוב]",  # street
            "POSTAL_CODE": "[מיקוד]",  # postal code
            "EMAIL": "[אימייל]",  # email
            "DATE": "[תאריך]",  # date
            "CC_PROVIDER": "[ספק כרטיס אשראי]",  # credit card provider
            "BANK": "[בנק]",  # bank
        }.get(entity_type, f"[{entity_type}]")

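
# Minimal usage sketch outside the Streamlit UI (illustrative; assumes the
# model repo is reachable and HF_TOKEN is set in .streamlit/secrets.toml):
#
#   model = PIIMaskingModel(MODELS["GolemPII XLM-RoBERTa v1"])
#   masked, elapsed, colored, tokens, labels, spans = model.process_text(
#       EXAMPLE_SENTENCES[0]
#   )
#   print(masked)  # PII values replaced by masks such as [טלפון] and [ת.ז]
#   print(spans)   # list of {"label", "start", "end", "value", "label_index"}
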
def save_results_to_file(results: Dict) -> str:
    """Save processing results to a timestamped JSON file and return its name."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"pii_masking_results_{timestamp}.json"
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    return filename

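
# The saved file mirrors the structure built in main() below (shape only,
# values illustrative):
# {
#   "GolemPII XLM-RoBERTa v1": {
#     "text_1": {
#       "original": "...", "masked": "...", "processing_time": 0.123,
#       "privacy_mask": [...], "span_labels": [[start, end, "LABEL"], ...]
#     }
#   }
# }
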

def main():
    st.set_page_config(layout="wide")
    st.title("GolemPII: Hebrew PII Masking Application")

    # Add CSS styles
    st.markdown(
        """
        <style>
        .rtl { direction: rtl; text-align: right; }
        .entity-legend { padding: 5px; margin: 2px; border-radius: 3px; display: inline-block; }
        .masked-text {
            direction: rtl;
            text-align: right;
            line-height: 2;
            padding: 10px;
            background-color: #f6f8fa;
            border-radius: 5px;
            color: black;
            white-space: pre-wrap;
        }
        /* Red headers for sections */
        .main h3 {
            color: #d73a49;
            margin-bottom: 10px;
        }
        /* Styles for the model details sidebar */
        .model-details-sidebar h2 {
            margin-top: 0;
        }
        .model-details-sidebar table {
            width: 100%;
            border-collapse: collapse;
        }
        .model-details-sidebar td, .model-details-sidebar th {
            padding: 8px;
            border: 1px solid #ddd;
            text-align: left;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # Sidebar configuration
    st.sidebar.header("Configuration")
    selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
    show_json = st.sidebar.checkbox("Show JSON Output", value=True)
    run_all_models = st.sidebar.checkbox("Run All Models")

    # Display model details in the sidebar
    st.sidebar.markdown(
        f"""
        <div class="model-details-sidebar">
        <h2>Model Details: {MODEL_DETAILS['name']}</h2>
        <p>{MODEL_DETAILS['description']}</p>
        <table>
        <tr><td>Base Model:</td><td>{MODEL_DETAILS['base_model']}</td></tr>
        <tr><td>Training Data:</td><td>{MODEL_DETAILS['training_data']}</td></tr>
        </table>
        <h3>Detected PII Entities</h3>
        <ul>
        {" ".join([f'<li><span class="entity-badge" style="background-color: {ENTITY_COLORS.get(entity, "#CCCCCC")}; padding: 3px 5px; border-radius: 3px; margin-right: 5px;">{entity}</span></li>' for entity in MODEL_DETAILS['detected_pii_entities']])}
        </ul>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Text input (the default joins the examples with newlines, so texts are
    # split on newlines rather than commas, which Hebrew prose may contain)
    text_input = st.text_area(
        "Enter text to mask (separate multiple texts with new lines):",
        value="\n".join(EXAMPLE_SENTENCES),
        height=200,
    )

    # Process button
    if st.button("Process Text"):
        texts = [text.strip() for text in text_input.splitlines() if text.strip()]
        if run_all_models:
            all_results = {}
            progress_bar = st.progress(0)
            for idx, (model_name, model_path) in enumerate(MODELS.items()):
                st.subheader(f"Results for {model_name}")
                model = PIIMaskingModel(model_path)
                model_results = {}
                for text_idx, text in enumerate(texts):
                    (
                        masked_text,
                        processing_time,
                        colored_text,
                        tokens,
                        predicted_labels,
                        privacy_masks,
                    ) = model.process_text(text)
                    model_results[f"text_{text_idx + 1}"] = {
                        "original": text,
                        "masked": masked_text,
                        "processing_time": processing_time,
                        "privacy_mask": privacy_masks,
                        "span_labels": [
                            [m["start"], m["end"], m["label"]] for m in privacy_masks
                        ],
                    }
                all_results[model_name] = model_results
                progress_bar.progress((idx + 1) / len(MODELS))

            # Save and display results
            filename = save_results_to_file(all_results)
            st.success(f"Results saved to {filename}")

            # Show comparison table
            comparison_data = []
            for model_name, results in all_results.items():
                avg_time = sum(
                    text_data["processing_time"] for text_data in results.values()
                ) / len(results)
                comparison_data.append(
                    {"Model": model_name, "Avg Processing Time": f"{avg_time:.3f}s"}
                )
            st.subheader("Model Comparison")
            st.table(pd.DataFrame(comparison_data))
        else:
            # Process with the single selected model
            model = PIIMaskingModel(MODELS[selected_model])
            for text in texts:
                st.markdown("### Original Text", unsafe_allow_html=True)
                st.markdown(f'<div class="rtl">{text}</div>', unsafe_allow_html=True)
                (
                    masked_text,
                    processing_time,
                    colored_text,
                    tokens,
                    predicted_labels,
                    privacy_masks,
                ) = model.process_text(text)
                st.markdown("### Masked Text", unsafe_allow_html=True)
                st.markdown(
                    f'<div class="masked-text">{colored_text}</div>',
                    unsafe_allow_html=True,
                )
                st.markdown(f"Processing Time: {processing_time:.3f} seconds")
                if show_json:
                    st.json(
                        {
                            "original": text,
                            "masked": masked_text,
                            "processing_time": processing_time,
                            "tokens": tokens,
                            "token_classes": predicted_labels,
                            "privacy_mask": privacy_masks,
                            "span_labels": [
                                [m["start"], m["end"], m["label"]]
                                for m in privacy_masks
                            ],
                        }
                    )
                st.markdown("---")


if __name__ == "__main__":
    main()