"""A Gradio app for de-identifying text data using FHE.""" | |
import base64 | |
import os | |
import re | |
import subprocess | |
import time | |
import uuid | |
from typing import Dict, List | |
import gradio as gr | |
import numpy | |
import pandas as pd | |
import requests | |
from fhe_anonymizer import FHEAnonymizer | |
from utils_demo import * | |
from concrete.ml.deployment import FHEModelClient | |
from models.speech_to_text import * | |
from models.speech_to_text.transcriber.audio import preprocess_audio | |
from models.speech_to_text.transcriber.model import load_model_and_processor | |
from models.speech_to_text.transcriber.audio import transcribe_audio | |
# Ensure the directory is clean before starting processes or reading files | |
clean_directory() | |
anonymizer = FHEAnonymizer() | |
# Start the Uvicorn server hosting the FastAPI app | |
subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR) | |
time.sleep(3) | |
# Load data from files required for the application | |
UUID_MAP = read_json(MAPPING_UUID_PATH) | |
MAPPING_DOC_EMBEDDING = read_pickle(MAPPING_DOC_EMBEDDING_PATH) | |
# Generate a random user ID for this session | |
USER_ID = numpy.random.randint(0, 2**32) | |
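
# Overview of the client/server workflow implemented below (the numbered steps match
# the prints in each function):
#   Step 1 (client): generate the secret and evaluation keys with FHEModelClient.
#   Step 2 (client): tokenize the prompt, embed each token, then quantize, encrypt
#                    and serialize it.
#   Step 3 (server): receive the evaluation key and the ciphertexts, run the FHE
#                    model, and return an encrypted probability for each token.
#   Step 4 (client): decrypt each probability and replace likely identifying tokens
#                    with short UUIDs.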

def key_gen_fn() -> Dict:
    """Generate keys for a given user."""
    print("------------ Step 1: Key Generation:")
    print(f"Your user ID is: {USER_ID}....")

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{USER_ID}")
    client.load()

    # Create the private and evaluation keys on the client side
    client.generate_private_and_evaluation_keys()

    # Get the serialized evaluation keys
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Save the evaluation key
    evaluation_key_path = KEYS_DIR / f"{USER_ID}/evaluation_key"
    write_bytes(evaluation_key_path, serialized_evaluation_keys)

    if not evaluation_key_path.is_file():
        error_message = (
            f"Error: an issue occurred while generating the evaluation key "
            f"{evaluation_key_path.is_file()=}"
        )
        print(error_message)
        return {gen_key_btn: gr.update(value=error_message)}
    else:
        print("Keys have been generated ✅")
        return {gen_key_btn: gr.update(value="Keys have been generated ✅")}

def encrypt_query_fn(query):
    print(f"\n------------ Step 2: Query encryption: {query=}")

    if not (KEYS_DIR / f"{USER_ID}/evaluation_key").is_file():
        return {
            output_encrypted_box: gr.update(value="Error ❌: Please generate the key first!", lines=8)
        }

    # `is_user_query_valid` is expected to return True when the query must be rejected
    if is_user_query_valid(query):
        return {
            query_box: gr.update(
                value=(
                    "Unable to process ❌: The request exceeds the length limit or falls "
                    "outside the scope. Please refine your query."
                )
            )
        }

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{USER_ID}")
    client.load()

    encrypted_tokens = []
    tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", query)

    for token in tokens:
        # Skip tokens made only of whitespace: only words and punctuation are encrypted
        if not bool(re.match(r"^\s+$", token)):
            # Embed the token, then quantize, encrypt and serialize it
            emb_x = get_batch_text_representation([token], EMBEDDINGS_MODEL, TOKENIZER)
            encrypted_x = client.quantize_encrypt_serialize(emb_x)
            assert isinstance(encrypted_x, bytes)
            encrypted_tokens.append(encrypted_x)

    print("Data encrypted ✅ on Client Side")

    # All ciphertexts have the same length, so they can be concatenated and later
    # split back using a single per-token length
    assert len({len(token) for token in encrypted_tokens}) == 1

    write_bytes(KEYS_DIR / f"{USER_ID}/encrypted_input", b"".join(encrypted_tokens))
    write_bytes(
        KEYS_DIR / f"{USER_ID}/encrypted_input_len", len(encrypted_tokens[0]).to_bytes(10, "big")
    )

    # Only a short hex slice of each ciphertext is shown in the UI
    encrypted_quant_tokens_hex = [token.hex()[500:580] for token in encrypted_tokens]

    return {
        output_encrypted_box: gr.update(value=" ".join(encrypted_quant_tokens_hex), lines=8),
        anonymized_query_output: gr.update(visible=True, value=None),
        identified_words_output_df: gr.update(visible=False, value=None),
    }
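
# Note on the files written by `encrypt_query_fn`: every serialized ciphertext has the
# same byte length, so the client stores one concatenated blob plus a single 10-byte
# big-endian length. A minimal sketch of how such a blob can be split back into
# per-token ciphertexts (`token_len` and `blob` are illustrative names only):
#
#   token_len = int.from_bytes(read_bytes(KEYS_DIR / f"{USER_ID}/encrypted_input_len"), "big")
#   blob = read_bytes(KEYS_DIR / f"{USER_ID}/encrypted_input")
#   ciphertexts = [blob[i : i + token_len] for i in range(0, len(blob), token_len)]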

def send_input_fn(query) -> Dict:
    print("------------ Step 3.1: Send encrypted_data to the Server")

    evaluation_key_path = KEYS_DIR / f"{USER_ID}/evaluation_key"
    encrypted_input_path = KEYS_DIR / f"{USER_ID}/encrypted_input"
    encrypted_input_len_path = KEYS_DIR / f"{USER_ID}/encrypted_input_len"

    if not evaluation_key_path.is_file() or not encrypted_input_path.is_file():
        error_message = (
            "Error: Key or encrypted input not found. Please generate the key and "
            "encrypt the query first."
        )
        return {anonymized_query_output: gr.update(value=error_message)}

    data = {"user_id": USER_ID, "input": query}
    files = [
        ("files", open(evaluation_key_path, "rb")),
        ("files", open(encrypted_input_path, "rb")),
        ("files", open(encrypted_input_len_path, "rb")),
    ]

    url = SERVER_URL + "send_input"
    with requests.post(url=url, data=data, files=files) as resp:
        print("Data sent to the server ✅" if resp.ok else "Error ❌ in sending data to the server")

def run_fhe_in_server_fn() -> Dict:
    print("------------ Step 3.2: Run in FHE on the Server Side")

    data = {"user_id": USER_ID}
    url = SERVER_URL + "run_fhe"

    with requests.post(url=url, data=data) as response:
        if not response.ok:
            return {
                anonymized_query_output: gr.update(
                    value="⚠️ An error occurred on the Server Side. Please check connectivity and data transmission."
                ),
            }
        else:
            time.sleep(1)
            # The server responds with the FHE execution time per token, in seconds
            print(f"The query anonymization was computed in {response.json():.2f} s per token.")

def get_output_fn() -> Dict:
    print("------------ Step 3.3: Get the output from the Server Side")

    data = {"user_id": USER_ID}
    url = SERVER_URL + "get_output"

    with requests.post(url=url, data=data) as response:
        if response.ok:
            print("Data received ✅ from the remote Server")
            response_data = response.json()

            # The server returns the encrypted output and its per-token length, base64-encoded
            encrypted_output = base64.b64decode(response_data["encrypted_output"])
            length_encrypted_output = base64.b64decode(response_data["length"])

            # Store them on the client side for the decryption step
            write_bytes(CLIENT_DIR / f"{USER_ID}_encrypted_output", encrypted_output)
            write_bytes(CLIENT_DIR / f"{USER_ID}_encrypted_output_len", length_encrypted_output)
        else:
            print("Error ❌ in getting data from the server")

def decrypt_fn(text) -> Dict:
    print("------------ Step 4: Decrypt the data on the `Client Side`")

    encrypted_output_path = CLIENT_DIR / f"{USER_ID}_encrypted_output"

    if not encrypted_output_path.is_file():
        error_message = "⚠️ Error: Encrypted output not found. Please ensure the entire process has been completed."
        print(error_message)
        return error_message, None

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{USER_ID}")
    client.load()

    # Retrieve the concatenated encrypted output and the length of a single encrypted token
    encrypted_output = read_bytes(CLIENT_DIR / f"{USER_ID}_encrypted_output")
    length = int.from_bytes(read_bytes(CLIENT_DIR / f"{USER_ID}_encrypted_output_len"), "big")

    # Tokenize the original text the same way as in the encryption step
    tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)

    decrypted_output, identified_words_with_prob = [], []
    i = 0

    for token in tokens:
        # Whitespace-only tokens were not encrypted, so they carry no ciphertext
        if not bool(re.match(r"^\s+$", token)):
            # Decrypt the probability that this token is identifying
            encrypted_token = encrypted_output[i : i + length]
            prediction_proba = client.deserialize_decrypt_dequantize(encrypted_token)
            probability = prediction_proba[0][1]
            i += length

            if probability >= 0.77:
                identified_words_with_prob.append((token, probability))
                # Reuse the cached UUID for this word if it exists, otherwise create a new one
                tmp_uuid = UUID_MAP.get(token, str(uuid.uuid4())[:8])
                decrypted_output.append(tmp_uuid)
                UUID_MAP[token] = tmp_uuid
            else:
                decrypted_output.append(token)

    # Persist the word-to-UUID mapping so identical words keep the same pseudonym
    write_json(MAPPING_UUID_PATH, UUID_MAP)

    # Remove the extra space introduced before punctuation when joining tokens
    anonymized_text = re.sub(r"\s([,.!?;:])", r"\1", " ".join(decrypted_output))

    identified_df = (
        pd.DataFrame(identified_words_with_prob, columns=["Identified Words", "Probability"])
        if identified_words_with_prob
        else pd.DataFrame(columns=["Identified Words", "Probability"])
    )

    print("Decryption done ✅ on Client Side")

    return anonymized_text, identified_df
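
# De-identification policy applied in `decrypt_fn`: a token is replaced whenever its
# decrypted probability of being identifying is >= 0.77. Replacements are 8-character
# UUID prefixes cached in UUID_MAP and persisted to MAPPING_UUID_PATH, so the same
# word maps to the same pseudonym across queries (e.g. "John" -> "3f2a91bc", an
# illustrative value only).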

def anonymization_with_fn(query):
    encrypt_query_fn(query)
    send_input_fn(query)
    run_fhe_in_server_fn()
    get_output_fn()

    anonymized_text, identified_df = decrypt_fn(query)

    return {
        anonymized_query_output: gr.update(value=anonymized_text),
        identified_words_output_df: gr.update(value=identified_df, visible=True),
    }

demo = gr.Blocks(css=".markdown-body { font-size: 18px; }")

with demo:
    gr.Markdown(
        """
        <h1 style="text-align: center;">Secure De-Identification of Text Data using FHE</h1>
        """
    )
    gr.Markdown(
        """
        <p align="center" style="font-size: 18px;">
        This demo showcases privacy-preserving de-identification of text data using Fully Homomorphic Encryption (FHE).
        </p>
        """
    )

    ########################## Key Gen Part ##########################

    gr.Markdown(
        "## Step 1: Generate the keys\n\n"
        """In Fully Homomorphic Encryption (FHE) methods, two types of keys are created: secret keys for encrypting and decrypting user data,
        and evaluation keys for the server to work on encrypted data without seeing the actual content."""
    )
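
    # Only the serialized evaluation key is ever uploaded to the server (see
    # `send_input_fn`); the secret key stays under KEYS_DIR on the client side and is
    # only used locally by FHEModelClient to encrypt and decrypt.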
    gen_key_btn = gr.Button("Generate the secret and evaluation keys")
    gen_key_btn.click(key_gen_fn, inputs=[], outputs=[gen_key_btn])

    ########################## User Query Part ##########################

    gr.Markdown("## Step 2: Enter the prompt you want to encrypt and de-identify")

    query_box = gr.Textbox(
        value="Hello. My name is John Doe. I live at 123 Main St, Anytown, USA.",
        label="Enter your prompt:",
        interactive=True,
    )

    encrypt_query_btn = gr.Button("Encrypt the prompt")

    output_encrypted_box = gr.Textbox(
        label="Encrypted prompt (will be sent to the de-identification server):",
        lines=4,
    )

    ########################## FHE processing Part ##########################

    gr.Markdown("## Step 3: De-identify the prompt using FHE")
    gr.Markdown(
        """The encrypted prompt will be sent to a remote server for de-identification using FHE.
        The server performs computations on the encrypted data and returns the result for the client to decrypt."""
    )

    run_fhe_btn = gr.Button("De-identify using FHE")

    anonymized_query_output = gr.Textbox(
        label="De-identified prompt", lines=4, interactive=True
    )

    identified_words_output_df = gr.Dataframe(label="Identified words:", visible=False)

    # The encryption button is wired up here, once every component it updates exists;
    # all components returned by `encrypt_query_fn` must be listed as outputs
    encrypt_query_btn.click(
        fn=encrypt_query_fn,
        inputs=[query_box],
        outputs=[query_box, output_encrypted_box, anonymized_query_output, identified_words_output_df],
    )
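
    # Clicking the button below runs `anonymization_with_fn`, which chains the full
    # pipeline on the current prompt: encrypt_query_fn -> send_input_fn ->
    # run_fhe_in_server_fn -> get_output_fn -> decrypt_fn.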
    run_fhe_btn.click(
        anonymization_with_fn,
        inputs=[query_box],
        outputs=[anonymized_query_output, identified_words_output_df],
    )

# Launch the app
demo.launch(share=False)