Spaces:
Sleeping
Sleeping
"""A local gradio app that detects seizures with EEG using FHE.""" | |
from PIL import Image | |
import os | |
import shutil | |
import subprocess | |
import time | |
import gradio as gr | |
import numpy | |
import requests | |
from itertools import chain | |
from client_server_interface import FHEClient | |
import requests | |
from requests.adapters import HTTPAdapter | |
from requests.packages.urllib3.util.retry import Retry | |
import logging | |
from common import ( | |
CLIENT_TMP_PATH, | |
SERVER_TMP_PATH, | |
EXAMPLES, | |
INPUT_SHAPE, | |
KEYS_PATH, | |
REPO_DIR, | |
SERVER_URL, | |
) | |
# Configure the root logger to emit records from DEBUG upwards; this module logs
# debug-level details (e.g. encrypted payload sizes) that would otherwise be hidden.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(__name__)
def requests_retry_session(
    retries=3,
    backoff_factor=0.3,
    status_forcelist=(500, 502, 504),
    session=None,
):
    """Return a ``requests`` session whose HTTP and HTTPS adapters retry failed requests.

    Args:
        retries (int): Budget used for the total, read and connect retry counters.
        backoff_factor (float): urllib3 exponential back-off factor between attempts.
        status_forcelist (Tuple[int]): HTTP statuses that trigger a retry.
        session (requests.Session): Optional existing session to configure; a new one is
            created when it is not provided (or is falsy).

    Returns:
        requests.Session: The configured session (the same object when one was passed in).
    """
    if not session:
        session = requests.Session()

    retry_policy = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    # A single adapter instance is shared by both schemes.
    retrying_adapter = HTTPAdapter(max_retries=retry_policy)
    for url_prefix in ("http://", "https://"):
        session.mount(url_prefix, retrying_adapter)
    return session
# Launch the FHE server (``uvicorn server:app`` run from REPO_DIR) as a background process so
# that the client app and the server run side by side, then pause to let it start.
# NOTE(review): a fixed delay does not guarantee the server is accepting requests yet.
_SERVER_STARTUP_DELAY_S = 3
_server_command = ["uvicorn", "server:app"]
subprocess.Popen(_server_command, cwd=REPO_DIR)
time.sleep(_SERVER_STARTUP_DELAY_S)
def shorten_bytes_object(bytes_object, limit=500, shift=100):
    """Shorten the input bytes object to a given length.

    Encrypted data is too large for displaying it in the browser using Gradio. This function
    provides a short hexadecimal representation of it: ``limit`` bytes taken after skipping the
    first ``shift`` bytes, so the result has at most ``2 * limit`` characters.

    When the object is not longer than ``shift``, the leading bytes are used instead. Without
    this, short inputs would produce an empty representation.

    Args:
        bytes_object (bytes): The input to shorten.
        limit (int): The number of bytes to keep. Default to 500.
        shift (int): The number of leading bytes to skip for a better display (presumably a
            serialization header common to every ciphertext). Default to 100.

    Returns:
        str: Hexadecimal string shorten representation of the input byte object.
    """
    start = shift if len(bytes_object) > shift else 0
    return bytes_object[start : start + limit].hex()
def get_client(user_id):
    """Build the FHE client API bound to the given user's key directory.

    Args:
        user_id (int): The current user's ID.

    Returns:
        FHEClient: A client created with ``KEYS_PATH / "seizure_detection_<user_id>"`` as its
        ``key_dir``.
    """
    user_key_dir = KEYS_PATH / f"seizure_detection_{user_id}"
    return FHEClient(key_dir=user_key_dir)
def get_client_file_path(name, user_id):
    """Return the client-side temporary path for the file ``name`` of a given user.

    Args:
        name (str): The desired file name prefix (e.g. "encrypted_image").
        user_id (int): The current user's ID.

    Returns:
        pathlib.Path: ``CLIENT_TMP_PATH / "<name>_seizure_detection_<user_id>"``.
    """
    file_name = f"{name}_seizure_detection_{user_id}"
    return CLIENT_TMP_PATH / file_name
def clean_temporary_files(n_keys=20):
    """Clean keys and encrypted files so that at most ``n_keys`` users' data is kept.

    Key directories in KEYS_PATH are sorted by modification time. When there are more than
    ``n_keys`` of them, the oldest ones are removed. Every client and server temporary file
    belonging to the same users is removed with them.

    Args:
        n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
    """
    import re

    # Oldest key directories first
    key_dirs = sorted(KEYS_PATH.iterdir(), key=os.path.getmtime)

    n_keys_to_delete = len(key_dirs) - n_keys
    if n_keys_to_delete <= 0:
        return

    deleted_key_names = []
    for key_dir in key_dirs[:n_keys_to_delete]:
        deleted_key_names.append(key_dir.name)
        shutil.rmtree(key_dir)

    # Match each deleted key name (e.g. "seizure_detection_12") only when it is not followed
    # by another digit. A plain substring test would also match "..._seizure_detection_123"
    # and delete another user's files.
    owner_patterns = [re.compile(re.escape(key_name) + r"(?!\d)") for key_name in deleted_key_names]

    # Delete all client and server files related to the ids whose keys were deleted
    for file in chain(CLIENT_TMP_PATH.iterdir(), SERVER_TMP_PATH.iterdir()):
        # any() stops at the first match, so a file is never unlinked twice
        if any(pattern.search(file.name) for pattern in owner_patterns):
            file.unlink()
def keygen():
    """Generate the private and evaluation keys for a new user.

    Cleans old temporary data, draws a new random user ID, generates the keys with that
    user's client and saves the serialized evaluation key to the client temporary folder.

    Returns:
        Tuple[int, bool]: The new user's ID and ``True`` (ticks the "key generated" checkbox).

    Raises:
        gr.Error: If any step of the key generation fails.
    """
    # ``secrets`` gives an unpredictable ID (it is the user's only handle on server-side files).
    # Unlike numpy.random.randint(0, 2**32), it cannot overflow on platforms where numpy's
    # default integer is 32 bits.
    import secrets

    logger.info("Starting key generation process")
    try:
        # Clean temporary files
        clean_temporary_files()

        # Create an ID for the current user
        user_id = secrets.randbelow(2**32)
        logger.info(f"Generated user_id: {user_id}")

        # Retrieve the client API
        client = get_client(user_id)
        logger.info("Retrieved client API")

        # Generate a private key
        logger.info("Generating private and evaluation keys")
        client.generate_private_and_evaluation_keys(force=True)
        logger.info("Private and evaluation keys generated successfully")

        # Retrieve the serialized evaluation key
        logger.info("Retrieving serialized evaluation keys")
        evaluation_key = client.get_serialized_evaluation_keys()
        logger.info("Serialized evaluation keys retrieved")

        # Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
        # buttons (see https://github.com/gradio-app/gradio/issues/1877)
        evaluation_key_path = get_client_file_path("evaluation_key", user_id)
        logger.info(f"Saving evaluation key to: {evaluation_key_path}")
        with evaluation_key_path.open("wb") as evaluation_key_file:
            evaluation_key_file.write(evaluation_key)
        logger.info("Evaluation key saved successfully")

        return (user_id, True)
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error message would lose
        logger.exception(f"Error during key generation: {str(e)}")
        raise gr.Error(f"Key generation failed: {str(e)}") from e
def encrypt(user_id, input_image):
    """Encrypt the given image for seizure detection.

    The image is resized to 32x32, converted to grayscale and reshaped to (1, 1, 32, 32). It is
    then cast to int16, encrypted and serialized with the user's client. The serialized
    ciphertext is written to a client temporary file, because it is too large to pass through
    Gradio components.

    Args:
        user_id (int): The current user's ID.
        input_image (numpy.ndarray): The image to encrypt: RGB (H, W, 3), RGBA (H, W, 4) or
            grayscale (H, W).

    Returns:
        Tuple[numpy.ndarray, str]: The 32x32 RGB image that was encrypted (upscaled for display)
        and a truncated hexadecimal representation of the encrypted image.

    Raises:
        gr.Error: If no private key was generated or no image was provided.
    """
    if user_id == "":
        raise gr.Error("Please generate the private key first.")

    if input_image is None:
        raise gr.Error("Please choose an image first.")

    # Normalize to 3 channels. A grayscale (H, W) image has no channel axis for the mean below
    # to reduce. An alpha channel would otherwise be averaged into the gray value.
    if input_image.ndim == 2:
        input_image = numpy.stack([input_image] * 3, axis=-1)
    elif input_image.ndim == 3 and input_image.shape[2] == 4:
        input_image = numpy.ascontiguousarray(input_image[..., :3])

    # Resize the image if it hasn't the shape (32, 32, 3)
    if input_image.shape != (32, 32, 3):
        input_image_pil = Image.fromarray(input_image)
        input_image_pil = input_image_pil.resize((32, 32))
        input_image = numpy.array(input_image_pil)

    # Convert RGB to grayscale by averaging the channels (truncated back to uint8)
    input_image_gray = numpy.mean(input_image, axis=2).astype(numpy.uint8)

    # Add the batch and channel axes: (1, 1, 32, 32)
    input_image_reshaped = input_image_gray.reshape(1, 1, 32, 32)

    # Cast to a signed integer type; 0-255 values fit in int16.
    # NOTE(review): an earlier comment mentioned "int12", presumably the model's input bit-width.
    input_image_int12 = input_image_reshaped.astype(numpy.int16)

    # Retrieve the client API
    client = get_client(user_id)

    # Pre-process, encrypt and serialize the image
    encrypted_image = client.encrypt_serialize(input_image_int12)

    # Save encrypted_image to bytes in a file, since too large to pass through regular Gradio
    # buttons, https://github.com/gradio-app/gradio/issues/1877
    encrypted_image_path = get_client_file_path("encrypted_image", user_id)
    with encrypted_image_path.open("wb") as encrypted_image_file:
        encrypted_image_file.write(encrypted_image)

    # Create a truncated version of the encrypted image for display
    encrypted_image_short = shorten_bytes_object(encrypted_image)

    return (resize_img(input_image), encrypted_image_short)
def send_input(user_id):
    """Send the encrypted input image as well as the evaluation key to the server.

    Args:
        user_id (int): The current user's ID.

    Returns:
        bool: Whether the server accepted the upload (``response.ok``).

    Raises:
        gr.Error: If the key or the encrypted image is missing, or the request fails.
    """
    # Get the evaluation key and encrypted image paths
    evaluation_key_path = get_client_file_path("evaluation_key", user_id)
    encrypted_input_path = get_client_file_path("encrypted_image", user_id)

    if user_id == "" or not evaluation_key_path.is_file():
        raise gr.Error("Please generate the private key first.")

    if not encrypted_input_path.is_file():
        raise gr.Error("Please generate the private key and then encrypt an image first.")

    # Define the data to post
    data = {
        "user_id": user_id,
    }
    url = SERVER_URL + "send_input"

    logger.info(f"Sending encrypted_image from: {encrypted_input_path}")
    logger.info(f"Sending evaluation_key from: {evaluation_key_path}")

    try:
        # Open the files in with-blocks so the handles are closed once the upload is done
        with encrypted_input_path.open("rb") as encrypted_input_file:
            with evaluation_key_path.open("rb") as evaluation_key_file:
                files = [
                    ("files", ("encrypted_image", encrypted_input_file, "application/octet-stream")),
                    ("files", ("evaluation_key", evaluation_key_file, "application/octet-stream")),
                ]

                # Send the encrypted input image and evaluation key to the server
                with requests.post(url=url, data=data, files=files) as response:
                    return response.ok
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to send the encrypted input to the server. Error: {str(e)}")
        raise gr.Error("Failed to connect to the server. Please check your network connection.") from e
def run_fhe(user_id):
    """Apply the seizure detection model on the encrypted image previously sent using FHE.

    Args:
        user_id (int): The current user's ID.

    Returns:
        The JSON-decoded body of the server's response, which the UI shows in the
        "Total FHE execution time" textbox.

    Raises:
        gr.Error: If no private key was generated, or the request times out, cannot connect,
            gets an error status or fails in any other way.
    """
    if user_id == "":
        raise gr.Error("Please generate the private key first.")

    data = {"user_id": user_id}
    url = SERVER_URL + "run_fhe"

    try:
        logger.info(f"Sending request to {url} with user_id: {user_id}")
        # The session owns a connection pool; the with-block closes it after the call.
        # NOTE: urllib3's default ``allowed_methods`` excludes POST. With that default, this
        # request is only retried on connection errors, not on read errors or 5xx statuses.
        with requests_retry_session() as session:
            with session.post(url=url, data=data, timeout=300) as response:
                logger.info(f"Received response with status code: {response.status_code}")
                # Raises HTTPError (a RequestException) for 4xx/5xx statuses, so a response
                # reaching the next line is successful
                response.raise_for_status()
                return response.json()
    except requests.exceptions.Timeout as e:
        logger.error("The request timed out. The server might be overloaded.")
        raise gr.Error("The request timed out. The server might be overloaded.") from e
    except requests.exceptions.ConnectionError as e:
        logger.error(f"Failed to connect to the server. Error: {str(e)}")
        raise gr.Error("Failed to connect to the server. Please check your network connection.") from e
    except requests.exceptions.RequestException as e:
        logger.error(f"An error occurred: {str(e)}")
        raise gr.Error(f"An error occurred: {str(e)}") from e
    except Exception as e:
        logger.exception(f"An unexpected error occurred: {str(e)}")
        raise gr.Error(f"An unexpected error occurred: {str(e)}") from e
def get_output(user_id):
    """Retrieve the encrypted output from the server and store it on the client side.

    The encrypted output is written to a client temporary file, because it is too large to pass
    through Gradio components.

    Args:
        user_id (int): The current user's ID.

    Returns:
        str: A truncated hexadecimal representation of the encrypted result.

    Raises:
        gr.Error: If no private key was generated, the server cannot be reached, or the server
            does not return the output (e.g. the FHE execution is not completed yet).
    """
    if user_id == "":
        raise gr.Error("Please generate the private key first.")

    data = {
        "user_id": user_id,
    }

    # Retrieve the encrypted output
    url = SERVER_URL + "get_output"
    try:
        response = requests.post(url=url, data=data)
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to retrieve the encrypted output. Error: {str(e)}")
        raise gr.Error("Failed to connect to the server. Please check your network connection.") from e

    with response:
        if not response.ok:
            raise gr.Error("Please wait for the FHE execution to be completed.")
        encrypted_output = response.content

    # Save the encrypted output to bytes in a file as it is too large to pass through regular
    # Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
    encrypted_output_path = get_client_file_path("encrypted_output", user_id)
    with encrypted_output_path.open("wb") as encrypted_output_file:
        encrypted_output_file.write(encrypted_output)

    # Create a truncated version of the encrypted output for display
    encrypted_output_short = shorten_bytes_object(encrypted_output)

    return encrypted_output_short
def decrypt_output(user_id):
    """Decrypt the result.

    Args:
        user_id (int): The current user's ID.

    Returns:
        str: The decrypted output message, e.g. "Seizure detected (Confidence: 0.87)".

    Raises:
        gr.Error: If no key or encrypted output is available, or the output is empty. Also
            raised if deserialization, decryption or interpretation of the result fails.
    """
    if user_id == "":
        raise gr.Error("Please generate the private key first.")

    # Get the encrypted output path
    encrypted_output_path = get_client_file_path("encrypted_output", user_id)

    if not encrypted_output_path.is_file():
        raise gr.Error("Please run the FHE execution first.")

    # Load the encrypted output as bytes
    with encrypted_output_path.open("rb") as encrypted_output_file:
        encrypted_output = encrypted_output_file.read()

    logger.debug(f"Encrypted output size: {len(encrypted_output)} bytes")
    logger.debug(f"Encrypted output (first 100 bytes): {encrypted_output[:100].hex()}")

    if not encrypted_output:
        raise gr.Error("The encrypted output is empty. Please try running the FHE execution again.")

    # Retrieve the client API
    client = get_client(user_id)

    # Deserialize, decrypt and post-process the encrypted output
    try:
        decrypted_output = client.deserialize_decrypt_post_process(encrypted_output)

        # Expect exactly two class scores (index 1 = seizure). Accept them flat, (2,), or with
        # a leading sample axis, (1, 2), since a single image is encrypted per call.
        scores = numpy.asarray(decrypted_output)
        if scores.size != 2:
            logger.error(f"Unexpected decrypted output format: {decrypted_output}")
            raise ValueError("Unexpected output format from the model")
        scores = scores.reshape(-1)

        predicted_class = int(numpy.argmax(scores))
        confidence = scores[predicted_class]
        result = "Seizure detected" if predicted_class == 1 else "No seizure detected"
        return f"{result} (Confidence: {confidence:.2f})"
    except RuntimeError as e:
        logger.error(f"Error during deserialization: {str(e)}")
        raise gr.Error("Failed to deserialize the encrypted output. The data might be corrupted or in an unexpected format.") from e
    except Exception as e:
        logger.error(f"Unexpected error during decryption: {str(e)}")
        raise gr.Error(f"An unexpected error occurred during decryption: {str(e)}") from e
def resize_img(img, width=256, height=256):
    """Return ``img`` resized to ``width`` x ``height`` pixels, as a NumPy array.

    Arrays that are not uint8 are cast to uint8 first, so PIL can build an image from them.

    Args:
        img (numpy.ndarray): The image to resize.
        width (int): Target width in pixels. Default to 256.
        height (int): Target height in pixels. Default to 256.

    Returns:
        numpy.ndarray: The resized image.
    """
    as_uint8 = img if img.dtype == numpy.uint8 else img.astype(numpy.uint8)
    resized = Image.fromarray(as_uint8).resize((width, height))
    return numpy.array(resized)
demo = gr.Blocks()

print("Starting the demo...")
with demo:
    gr.Markdown(
        """
        <h1 align="center">Seizure Detection on Encrypted EEG Data Using Fully Homomorphic Encryption</h1>
        """
    )

    gr.Markdown("## Client side")
    gr.Markdown("### Step 1: Upload an EEG image. ")
    gr.Markdown(
        "The image will automatically be resized to shape (32x32). "
        "After encryption, the image shown here is replaced by the 32x32 version that was "
        "actually encrypted, upscaled for display."
    )
    with gr.Row():
        input_image = gr.Image(
            value=None,
            label="Upload an EEG image here.",
            height=256,
            width=256,
            sources="upload",
            interactive=True,
        )

        examples = gr.Examples(
            examples=EXAMPLES, inputs=[input_image], examples_per_page=5, label="Examples to use."
        )

    gr.Markdown("### Step 2: Generate the private key.")
    keygen_button = gr.Button("Generate the private key.")

    with gr.Row():
        keygen_checkbox = gr.Checkbox(label="Private key generated:", interactive=False)

    # Hidden textbox carrying the current user's ID between callbacks
    user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)

    gr.Markdown("### Step 3: Encrypt the image using FHE.")
    encrypt_button = gr.Button("Encrypt the image using FHE.")

    with gr.Row():
        encrypted_input = gr.Textbox(
            label="Encrypted input representation:", max_lines=2, interactive=False
        )

    gr.Markdown("## Server side")
    gr.Markdown(
        "The encrypted value is received by the server. The server can then compute the seizure "
        "detection directly over encrypted values. Once the computation is finished, the server returns "
        "the encrypted results to the client."
    )

    gr.Markdown("### Step 4: Send the encrypted image to the server.")
    send_input_button = gr.Button("Send the encrypted image to the server.")
    send_input_checkbox = gr.Checkbox(label="Encrypted image sent.", interactive=False)

    gr.Markdown("### Step 5: Run FHE execution.")
    execute_fhe_button = gr.Button("Run FHE execution.")
    fhe_execution_time = gr.Textbox(
        label="Total FHE execution time (in seconds):", max_lines=1, interactive=False
    )

    gr.Markdown("### Step 6: Receive the encrypted output from the server.")
    get_output_button = gr.Button("Receive the encrypted output from the server.")

    with gr.Row():
        encrypted_output = gr.Textbox(
            label="Encrypted output representation:",
            max_lines=2,
            interactive=False,
        )

    gr.Markdown("## Client side")
    gr.Markdown(
        "The encrypted output is sent back to the client, who can finally decrypt it with the "
        "private key. Only the client is aware of the original image and the detection result."
    )

    gr.Markdown("### Step 7: Decrypt the output.")
    decrypt_button = gr.Button("Decrypt the output")

    with gr.Row():
        decrypted_output = gr.Textbox(
            label="Seizure detection result:",
            interactive=False,
        )

    # Button to generate the private key
    keygen_button.click(
        keygen,
        outputs=[user_id, keygen_checkbox],
    )

    # Button to encrypt inputs on the client side
    encrypt_button.click(
        encrypt,
        inputs=[user_id, input_image],
        outputs=[input_image, encrypted_input],
    )

    # Button to send the encodings to the server using post method
    send_input_button.click(
        send_input, inputs=[user_id], outputs=[send_input_checkbox]
    )

    # Button to run the FHE execution on the server
    execute_fhe_button.click(run_fhe, inputs=[user_id], outputs=[fhe_execution_time])

    # Button to retrieve the encrypted output from the server
    get_output_button.click(
        get_output,
        inputs=[user_id],
        outputs=[encrypted_output],
    )

    # Button to decrypt the output on the client side
    decrypt_button.click(
        decrypt_output,
        inputs=[user_id],
        outputs=[decrypted_output],
    )

    gr.Markdown(
        "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a "
        "Privacy-Preserving Machine Learning (PPML) open-source set of tools by [Zama](https://zama.ai/). "
        "Try it yourself and don't forget to star on Github ⭐."
    )

demo.launch(share=False)