# Zamanonymize3 / app.py
# Last commit: 8a75eb3 "Update app.py" by mzameshina (verified); file size 11.1 kB.
"""A Gradio app for de-identifying text data using FHE."""
import base64
import os
import re
import subprocess
import time
import uuid
from typing import Dict, List
import gradio as gr
import numpy
import pandas as pd
import requests
from fhe_anonymizer import FHEAnonymizer
from utils_demo import *
from concrete.ml.deployment import FHEModelClient
from models.speech_to_text import *
from models.speech_to_text.transcriber.audio import preprocess_audio
from models.speech_to_text.transcriber.model import load_model_and_processor
from models.speech_to_text.transcriber.audio import transcribe_audio
# Ensure the directory is clean before starting processes or reading files
clean_directory()
# Instantiated at import time; the name `anonymizer` is not used elsewhere in
# the visible code (its construction may matter for side effects — confirm
# before removing).
anonymizer = FHEAnonymizer()
# Start the Uvicorn server hosting the FastAPI app (`server:app`, run from
# CURRENT_DIR). The Popen handle is discarded, so this script never waits for
# or stops the server process.
subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
# Fixed grace period for the server to start; there is no readiness check, so a
# slower start-up can make the first requests to the server fail.
time.sleep(3)
# Load data from files required for the application.
# UUID_MAP: token -> short replacement ID; decrypt_fn extends it and writes it
# back to MAPPING_UUID_PATH.
UUID_MAP = read_json(MAPPING_UUID_PATH)
# Not referenced elsewhere in the visible code.
MAPPING_DOC_EMBEDDING = read_pickle(MAPPING_DOC_EMBEDDING_PATH)
# Random user ID, drawn once at import time. Every visitor served by this
# process shares it (and therefore the same key and output files), so it is
# per-process rather than per browser session.
# NOTE(review): with a 32-bit default NumPy integer (e.g. Windows, NumPy < 2),
# a high bound of 2**32 may raise ValueError — confirm on target platforms.
USER_ID = numpy.random.randint(0, 2**32)
def key_gen_fn() -> Dict:
    """Create this session's FHE private and evaluation keys.

    The client stores its keys under ``KEYS_DIR / USER_ID``. The serialized
    evaluation key is also written to ``KEYS_DIR / USER_ID / "evaluation_key"``
    so it can later be uploaded to the server.

    Returns:
        A Gradio update dict for ``gen_key_btn``. Its label reports success, or
        the failure to find the written evaluation key.
    """
    print("------------ Step 1: Key Generation:")
    print(f"Your user ID is: {USER_ID}....")

    user_key_dir = KEYS_DIR / f"{USER_ID}"
    fhe_client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=user_key_dir)
    fhe_client.load()

    # Both key kinds are created locally; only the evaluation key is serialized
    # for sharing.
    fhe_client.generate_private_and_evaluation_keys()
    eval_key_bytes = fhe_client.get_serialized_evaluation_keys()
    assert isinstance(eval_key_bytes, bytes)

    # The name `evaluation_key_path` appears verbatim in the error message below
    # (self-documenting f-string), so it is kept as is.
    evaluation_key_path = user_key_dir / "evaluation_key"
    write_bytes(evaluation_key_path, eval_key_bytes)

    if evaluation_key_path.is_file():
        print("Keys have been generated βœ…")
        return {gen_key_btn: gr.update(value="Keys have been generated βœ…")}

    failure_message = f"Error Encountered While generating the evaluation {evaluation_key_path.is_file()=}"
    print(failure_message)
    return {gen_key_btn: gr.update(value=failure_message)}
def encrypt_query_fn(query):
    """Tokenize ``query``, then quantize, encrypt and serialize each token.

    Each token that is not pure whitespace is embedded with
    ``get_batch_text_representation`` and encrypted with this session's FHE
    client. Two files are written:

    - ``KEYS_DIR / USER_ID / "encrypted_input"``: the ciphertexts, concatenated.
    - ``encrypted_input_len``: the byte length of one ciphertext, as 10 bytes
      big-endian.

    Args:
        query: The user prompt.

    Returns:
        A Gradio update dict.

        - On success: a hex preview of the ciphertexts in
          ``output_encrypted_box``, with ``anonymized_query_output`` and
          ``identified_words_output_df`` reset.
        - On failure: an error message in ``output_encrypted_box`` or
          ``query_box``, and no ``anonymized_query_output`` entry.
    """
    print(f"\n------------ Step 2: Query encryption: {query=}")

    if not (KEYS_DIR / f"{USER_ID}/evaluation_key").is_file():
        return {output_encrypted_box: gr.update(value="Error ❌: Please generate the key first!", lines=8)}

    # NOTE(review): despite its name, a truthy result from `is_user_query_valid`
    # is treated here as a rejection of the query — confirm against utils_demo
    # before changing this condition.
    if is_user_query_valid(query):
        return {
            query_box: gr.update(
                value="Unable to process ❌: The request exceeds the length limit or falls outside the scope. Please refine your query."
            )
        }

    # Split the prompt into word-like tokens and runs of whitespace/punctuation.
    # Runs made only of whitespace are not encrypted. decrypt_fn must apply the
    # same rule so that tokens and ciphertexts line up.
    tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", query)
    tokens_to_encrypt = [token for token in tokens if not re.match(r"^\s+$", token)]
    if not tokens_to_encrypt:
        # Without this guard, the size check and `encrypted_tokens[0]` below
        # fail for an empty or whitespace-only prompt.
        return {output_encrypted_box: gr.update(value="Error: The prompt is empty, there is nothing to encrypt.", lines=8)}

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{USER_ID}")
    client.load()

    encrypted_tokens = []
    for token in tokens_to_encrypt:
        emb_x = get_batch_text_representation([token], EMBEDDINGS_MODEL, TOKENIZER)
        encrypted_x = client.quantize_encrypt_serialize(emb_x)
        assert isinstance(encrypted_x, bytes)
        encrypted_tokens.append(encrypted_x)

    print("Data encrypted βœ… on Client Side")

    # The concatenated blob is delimited only by the single per-token length
    # written below, so every ciphertext must have the same size.
    assert len({len(token) for token in encrypted_tokens}) == 1

    write_bytes(KEYS_DIR / f"{USER_ID}/encrypted_input", b"".join(encrypted_tokens))
    write_bytes(KEYS_DIR / f"{USER_ID}/encrypted_input_len", len(encrypted_tokens[0]).to_bytes(10, "big"))

    # Show only a short slice of each ciphertext's hex form as a preview.
    encrypted_quant_tokens_hex = [token.hex()[500:580] for token in encrypted_tokens]

    return {
        output_encrypted_box: gr.update(value=" ".join(encrypted_quant_tokens_hex), lines=8),
        anonymized_query_output: gr.update(visible=True, value=None),
        identified_words_output_df: gr.update(visible=False, value=None),
    }
def send_input_fn(query) -> Dict:
    """Upload the evaluation key and the encrypted query to the FHE server.

    Posts to ``SERVER_URL + "send_input"``:

    - form fields ``user_id`` and ``input``;
    - the files ``evaluation_key``, ``encrypted_input`` and
      ``encrypted_input_len`` from ``KEYS_DIR / USER_ID``.

    Args:
        query: The clear-text prompt, sent as the ``input`` form field.

    Returns:
        A Gradio update dict for ``anonymized_query_output`` with an error
        message if a required file is missing. Otherwise ``None``; the upload
        outcome is only printed.
    """
    from contextlib import ExitStack

    print("------------ Step 3.1: Send encrypted_data to the Server")
    evaluation_key_path = KEYS_DIR / f"{USER_ID}/evaluation_key"
    encrypted_input_path = KEYS_DIR / f"{USER_ID}/encrypted_input"
    encrypted_input_len_path = KEYS_DIR / f"{USER_ID}/encrypted_input_len"
    upload_paths = (evaluation_key_path, encrypted_input_path, encrypted_input_len_path)

    # All three files are uploaded, so check every one of them. Otherwise a
    # missing length file surfaces as FileNotFoundError instead of this message.
    if not all(path.is_file() for path in upload_paths):
        error_message = "Error: Key or encrypted input not found. Please generate the key and encrypt the query first."
        return {anonymized_query_output: gr.update(value=error_message)}

    # NOTE(review): the clear-text prompt is sent to the server together with
    # the ciphertexts. Confirm whether the server really needs it, since this
    # defeats the purpose of encrypting the query.
    data = {"user_id": USER_ID, "input": query}
    url = SERVER_URL + "send_input"

    # ExitStack closes every opened file once the upload finishes, including
    # when the request raises.
    with ExitStack() as stack:
        files = [("files", stack.enter_context(open(path, "rb"))) for path in upload_paths]
        with requests.post(url=url, data=data, files=files) as resp:
            print("Data sent to the server βœ…" if resp.ok else "Error ❌ in sending data to the server")
def run_fhe_in_server_fn() -> Dict:
    """Ask the server to run the FHE de-identification on the uploaded query.

    Posts ``user_id`` to ``SERVER_URL + "run_fhe"``. On success, it prints the
    per-token timing returned as the JSON body.

    Returns:
        ``None`` on success, or a Gradio update dict carrying an error message
        for ``anonymized_query_output`` when the server answers with an error
        status.
    """
    print("------------ Step 3.2: Run in FHE on the Server Side")
    with requests.post(url=SERVER_URL + "run_fhe", data={"user_id": USER_ID}) as reply:
        if reply.ok:
            # Fixed 1 s pause before reading the result; the reason is not
            # evident from this code.
            time.sleep(1)
            print(f"The query anonymization was computed in {reply.json():.2f} s per token.")
            return None
        return {
            anonymized_query_output: gr.update(
                value="⚠️ An error occurred on the Server Side. Please check connectivity and data transmission."
            ),
        }
def get_output_fn() -> Dict:
    """Download the encrypted de-identification result from the FHE server.

    Posts ``user_id`` to ``SERVER_URL + "get_output"``. On success, two fields
    of the JSON response are base64-decoded and written to disk:

    - ``encrypted_output`` -> ``CLIENT_DIR / f"{USER_ID}_encrypted_output"``
    - ``length`` -> ``CLIENT_DIR / f"{USER_ID}_encrypted_output_len"``

    Returns:
        ``None`` on success. On failure, previously downloaded output files are
        deleted, so a stale result from an earlier query cannot be decrypted
        against the current prompt. A Gradio update dict with an error message
        for ``anonymized_query_output`` is then returned, mirroring
        ``run_fhe_in_server_fn``.
    """
    print("------------ Step 3.3: Get the output from the Server Side")
    data = {"user_id": USER_ID}
    url = SERVER_URL + "get_output"
    encrypted_output_path = CLIENT_DIR / f"{USER_ID}_encrypted_output"
    encrypted_output_len_path = CLIENT_DIR / f"{USER_ID}_encrypted_output_len"

    with requests.post(url=url, data=data) as response:
        if response.ok:
            print("Data received βœ… from the remote Server")
            response_data = response.json()
            encrypted_output = base64.b64decode(response_data["encrypted_output"])
            length_encrypted_output = base64.b64decode(response_data["length"])
            write_bytes(encrypted_output_path, encrypted_output)
            write_bytes(encrypted_output_len_path, length_encrypted_output)
            return None

    print("Error ❌ in getting data from the server")
    # Drop any result left over from a previous query so decrypt_fn reports
    # "not found" instead of decrypting mismatched data.
    for stale_path in (encrypted_output_path, encrypted_output_len_path):
        stale_path.unlink(missing_ok=True)
    return {
        anonymized_query_output: gr.update(
            value="Error: Could not retrieve the encrypted output from the server."
        )
    }
def decrypt_fn(text, *, threshold: float = 0.77) -> tuple:
    """Decrypt the server's per-token predictions and de-identify ``text``.

    ``text`` is re-tokenized with the same rule as ``encrypt_query_fn``. Each
    token that is not pure whitespace is paired, in order, with one
    ``length``-byte slice of the encrypted server output. A token is replaced
    by a short UUID when its decrypted value ``prediction[0][1]`` is at least
    ``threshold``. The UUID is reused across calls through ``UUID_MAP``, which
    is persisted to ``MAPPING_UUID_PATH``.

    Args:
        text: The clear-text prompt that was encrypted.
        threshold: Minimum value of ``prediction[0][1]`` (presumably the
            probability that the token is identifying) for the token to be
            replaced. Defaults to 0.77.

    Returns:
        ``(anonymized_text, identified_df)``. ``identified_df`` has the columns
        ``"Identified Words"`` and ``"Probability"``. If the encrypted output
        is missing or does not match ``text``, returns
        ``(error_message, None)`` instead.
    """
    print("------------ Step 4: Decrypt the data on the `Client Side`")
    encrypted_output_path = CLIENT_DIR / f"{USER_ID}_encrypted_output"
    encrypted_output_len_path = CLIENT_DIR / f"{USER_ID}_encrypted_output_len"
    if not encrypted_output_path.is_file() or not encrypted_output_len_path.is_file():
        error_message = "⚠️ Error: Encrypted output not found. Please ensure the entire process has been completed."
        print(error_message)
        return error_message, None

    encrypted_output = read_bytes(encrypted_output_path)
    length = int.from_bytes(read_bytes(encrypted_output_len_path), "big")

    # Same tokenization and whitespace filter as encrypt_query_fn, so the i-th
    # kept token lines up with the i-th ciphertext slice.
    tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
    tokens = [token for token in tokens if not re.match(r"^\s+$", token)]

    # Decryption consumes exactly `length` bytes per kept token. A blob of any
    # other size cannot belong to this prompt (e.g. a stale result from an
    # earlier query) and would otherwise fail inside deserialization.
    if length <= 0 or len(encrypted_output) != length * len(tokens):
        error_message = "Error: The encrypted output does not match the current prompt. Please run the whole process again."
        print(error_message)
        return error_message, None

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{USER_ID}")
    client.load()

    decrypted_output, identified_words_with_prob = [], []
    for index, token in enumerate(tokens):
        encrypted_token = encrypted_output[index * length : (index + 1) * length]
        prediction_proba = client.deserialize_decrypt_dequantize(encrypted_token)
        probability = prediction_proba[0][1]
        if probability >= threshold:
            identified_words_with_prob.append((token, probability))
            # Reuse an existing replacement so a value always maps to the same ID.
            decrypted_output.append(UUID_MAP.setdefault(token, str(uuid.uuid4())[:8]))
        else:
            decrypted_output.append(token)

    write_json(MAPPING_UUID_PATH, UUID_MAP)

    # Tokens are re-joined with single spaces. Punctuation tokens may already
    # carry spaces (e.g. ". "), so first collapse space runs, then remove the
    # space placed before punctuation.
    anonymized_text = re.sub(r" {2,}", " ", " ".join(decrypted_output))
    anonymized_text = re.sub(r"\s([,.!?;:])", r"\1", anonymized_text)

    # An empty list still yields a DataFrame with the two named columns.
    identified_df = pd.DataFrame(identified_words_with_prob, columns=["Identified Words", "Probability"])
    print("Decryption done βœ… on Client Side")
    return anonymized_text, identified_df
def anonymization_with_fn(query):
    """Run the full de-identification pipeline for ``query`` in one call.

    The steps are encryption, upload, server-side FHE evaluation, download and
    decryption. The pipeline stops at the first step that reports a failure,
    so the user sees that error instead of a decryption of missing or stale
    files.

    Args:
        query: The clear-text prompt.

    Returns:
        A Gradio update dict for ``anonymized_query_output`` and
        ``identified_words_output_df``.
    """
    hidden_df = gr.update(visible=False, value=None)

    # encrypt_query_fn always returns an update dict, but only its success dict
    # contains `anonymized_query_output`. The key's absence signals a failure.
    encryption_result = encrypt_query_fn(query)
    if not encryption_result or anonymized_query_output not in encryption_result:
        return {
            anonymized_query_output: gr.update(
                value="Error: The prompt could not be encrypted. Generate the keys first and make sure the prompt is within the supported length and scope."
            ),
            identified_words_output_df: hidden_df,
        }

    # Each of these steps returns a falsy value on success, or an update dict
    # for `anonymized_query_output` when it detects a failure. They are run
    # one after another, so a later step only starts once the earlier ones
    # succeeded.
    remaining_steps = (
        lambda: send_input_fn(query),
        run_fhe_in_server_fn,
        get_output_fn,
    )
    for run_step in remaining_steps:
        step_error = run_step()
        if step_error:
            return {**step_error, identified_words_output_df: hidden_df}

    anonymized_text, identified_df = decrypt_fn(query)
    return {
        anonymized_query_output: gr.update(value=anonymized_text),
        # decrypt_fn returns None instead of a DataFrame when it fails.
        identified_words_output_df: gr.update(value=identified_df, visible=identified_df is not None),
    }
demo = gr.Blocks(css=".markdown-body { font-size: 18px; }")

# UI layout follows component creation order; event handlers are registered at
# the end, once every component they read or update exists.
with demo:
    gr.Markdown(
        """
<h1 style="text-align: center;">Secure De-Identification of Text Data using FHE</h1>
"""
    )
    gr.Markdown(
        """
<p align="center" style="font-size: 18px;">
This demo showcases privacy-preserving de-identification of text data using Fully Homomorphic Encryption (FHE).
</p>
"""
    )

    ########################## Key Gen Part ##########################
    gr.Markdown(
        "## Step 1: Generate the keys\n\n"
        """In Fully Homomorphic Encryption (FHE) methods, two types of keys are created: secret keys for encrypting and decrypting user data,
and evaluation keys for the server to work on encrypted data without seeing the actual content."""
    )
    gen_key_btn = gr.Button("Generate the secret and evaluation keys")
    gen_key_btn.click(key_gen_fn, inputs=[], outputs=[gen_key_btn])

    ########################## User Query Part ##########################
    gr.Markdown("## Step 2: Enter the prompt you want to encrypt and de-identify")
    query_box = gr.Textbox(
        value="Hello. My name is John Doe. I live at 123 Main St, Anytown, USA.",
        label="Enter your prompt:",
        interactive=True,
    )
    encrypt_query_btn = gr.Button("Encrypt the prompt")
    output_encrypted_box = gr.Textbox(
        label="Encrypted prompt (will be sent to the de-identification server):",
        lines=4,
    )

    ########################## FHE processing Part ##########################
    gr.Markdown("## Step 3: De-identify the prompt using FHE")
    gr.Markdown(
        """The encrypted prompt will be sent to a remote server for de-identification using FHE.
The server performs computations on the encrypted data and returns the result for decryption."""
    )
    run_fhe_btn = gr.Button("De-identify using FHE")
    anonymized_query_output = gr.Textbox(
        label="De-identified prompt", lines=4, interactive=True
    )
    identified_words_output_df = gr.Dataframe(label="Identified words:", visible=False)

    ########################## Event wiring ##########################
    # On success, encrypt_query_fn also resets `anonymized_query_output` and
    # `identified_words_output_df`. Gradio raises an error when a handler
    # returns an update for a component that is not listed in `outputs`. These
    # components must therefore be declared here, which is only possible after
    # they have been created above.
    encrypt_query_btn.click(
        fn=encrypt_query_fn,
        inputs=[query_box],
        outputs=[query_box, output_encrypted_box, anonymized_query_output, identified_words_output_df],
    )
    run_fhe_btn.click(
        anonymization_with_fn,
        inputs=[query_box],
        outputs=[anonymized_query_output, identified_words_output_df],
    )

# Launch the app
demo.launch(share=False)