|
import gensim |
|
import re |
|
from concrete.ml.deployment import FHEModelClient, FHEModelServer |
|
from pathlib import Path |
|
from concrete.ml.common.serialization.loaders import load |
|
import uuid |
|
import json |
|
from transformers import AutoTokenizer, AutoModel |
|
from utils_demo import get_batch_text_representation |
|
|
|
# Directory that contains this module. The UUID mapping JSON file and the
# "deployment" directory holding the FHE model artifacts are located
# relative to it.
base_dir: Path = Path(__file__).parent
|
|
|
|
|
class FHEAnonymizer:
    """Replace sensitive words in text with short identifiers using FHE.

    Every word-like token of the input is embedded with the
    ``obi/deid_roberta_i2b2`` Hugging Face model and classified by a
    Concrete ML model that is evaluated on encrypted data. Tokens whose
    predicted probability is >= 0.5 are replaced by an identifier. New
    identifiers are the first 8 characters of a random UUID4.

    The token -> identifier mapping is loaded from, and saved back to,
    ``original_document_uuid_mapping.json`` in ``base_dir``. A token that was
    already seen therefore keeps the same identifier across calls.
    """

    # Name of the JSON file (relative to ``base_dir``) holding the
    # token -> identifier mapping.
    _UUID_MAP_FILENAME = "original_document_uuid_mapping.json"

    # The pattern splits text into three kinds of token:
    #   (1) word-like runs that may contain . / - @ between word characters;
    #   (2) runs of whitespace / punctuation;
    #   (3) as a last resort, any other single character.
    # Alternative (3) guarantees that "".join(tokens) == text. Without it,
    # characters such as "(", "#" or "$" were silently dropped from the
    # anonymized output.
    _TOKEN_PATTERN = re.compile(
        r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+|.)", re.DOTALL
    )

    # Only tokens starting with a word character are sent to the classifier.
    _WORD_START = re.compile(r"\w")

    def __init__(self, punctuation_list=".,!?:;"):
        """Load the models and identifier mapping, and create FHE keys.

        Args:
            punctuation_list: Stored as ``self.punctuation_list``. It is not
                used by the methods of this class.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

        self.punctuation_list = punctuation_list

        with open(
            base_dir / self._UUID_MAP_FILENAME, "r", encoding="utf-8"
        ) as file:
            self.uuid_map = json.load(file)

        # Client and server live in the same process. The server is only
        # given the encrypted input and the serialized evaluation key
        # (see fhe_inference).
        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        """Run the classifier on ``x`` under FHE and return decrypted output.

        ``x`` is quantized, encrypted and serialized by the client. The server
        evaluates the model on the ciphertext, and the client then
        deserializes, decrypts and dequantizes the result. ``__call__``
        indexes the result as ``[0][1]`` (presumably the positive-class
        probability of the first sample; confirm against the model).
        """
        enc_x = self.client.quantize_encrypt_serialize(x)
        enc_y = self.server.run(enc_x, self.evaluation_key)
        return self.client.deserialize_decrypt_dequantize(enc_y)

    def _save_uuid_map(self):
        """Write ``self.uuid_map`` to the JSON mapping file in ``base_dir``."""
        with open(
            base_dir / self._UUID_MAP_FILENAME, "w", encoding="utf-8"
        ) as file:
            json.dump(self.uuid_map, file)

    def __call__(self, text: str):
        """Anonymize ``text`` by replacing words classified as sensitive.

        Whitespace, punctuation and any other token that does not start with
        a word character are copied through unchanged. Every other token is
        classified individually, and is replaced by its identifier if its
        probability is >= 0.5. The mapping file is rewritten only when a new
        identifier had to be created.

        Args:
            text: The input text.

        Returns:
            A tuple ``(anonymized_text, identified_words_with_prob)``:
            - ``anonymized_text`` is ``text`` with sensitive words replaced.
            - ``identified_words_with_prob`` lists the ``(word, probability)``
              pairs that were replaced, in order of appearance.
        """
        identified_words_with_prob = []
        processed_tokens = []
        mapping_changed = False

        for token in self._TOKEN_PATTERN.findall(text):
            if not token.strip() or not self._WORD_START.match(token):
                processed_tokens.append(token)
                continue

            x = get_batch_text_representation(
                [token], self.embeddings_model, self.tokenizer
            )
            probability = self.fhe_inference(x)[0][1]

            if probability >= 0.5:
                identified_words_with_prob.append((token, probability))
                tmp_uuid = self.uuid_map.get(token)
                if tmp_uuid is None:
                    # Draw a fresh identifier only for tokens not seen before.
                    tmp_uuid = str(uuid.uuid4())[:8]
                    self.uuid_map[token] = tmp_uuid
                    mapping_changed = True
                processed_tokens.append(tmp_uuid)
            else:
                processed_tokens.append(token)

        if mapping_changed:
            self._save_uuid_map()

        return "".join(processed_tokens), identified_words_with_prob
|
|