batuergun commited on
Commit
5bebde6
1 Parent(s): 7f46797
Files changed (1) hide show
  1. app.py +56 -25
app.py CHANGED
@@ -13,6 +13,10 @@ import requests
13
  from requests.adapters import HTTPAdapter
14
  from requests.packages.urllib3.util.retry import Retry
15
  import logging
 
 
 
 
16
 
17
  from common import (
18
  CLIENT_TMP_PATH,
@@ -310,36 +314,58 @@ def get_output(user_id):
310
  else:
311
  raise gr.Error("Please wait for the FHE execution to be completed.")
312
 
313
- def decrypt_output(user_id):
314
- """Decrypt the result.
 
 
315
 
316
- Args:
317
- user_id (int): The current user's ID.
318
 
319
- Returns:
320
- bool: The decrypted output (True if seizure detected, False otherwise)
321
 
322
- """
323
- if user_id == "":
324
- raise gr.Error("Please generate the private key first.")
325
 
326
- # Get the encrypted output path
327
- encrypted_output_path = get_client_file_path("encrypted_output", user_id)
328
 
329
- if not encrypted_output_path.is_file():
330
- raise gr.Error("Please run the FHE execution first.")
331
 
332
- # Load the encrypted output as bytes
333
- with encrypted_output_path.open("rb") as encrypted_output_file:
334
- encrypted_output = encrypted_output_file.read()
335
 
336
- # Retrieve the client API
337
- client = get_client(user_id)
338
 
339
- # Deserialize, decrypt and post-process the encrypted output
340
- decrypted_output = client.deserialize_decrypt_post_process(encrypted_output)
 
 
 
 
341
 
342
- return "Seizure detected" if decrypted_output else "No seizure detected"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
  def resize_img(img, width=256, height=256):
345
  """Resize the image."""
@@ -430,10 +456,15 @@ with demo:
430
  decrypt_button = gr.Button("Decrypt the output")
431
 
432
  with gr.Row():
433
- decrypted_output = gr.Textbox(
434
- label="Seizure detection result:",
435
- interactive=False
436
- )
 
 
 
 
 
437
 
438
  # Button to generate the private key
439
  keygen_button.click(
 
13
  from requests.adapters import HTTPAdapter
14
  from requests.packages.urllib3.util.retry import Retry
15
  import logging
16
+ import numpy as np
17
+ import seaborn as sns
18
+ import io
19
+ import matplotlib.pyplot as plt
20
 
21
  from common import (
22
  CLIENT_TMP_PATH,
 
314
  else:
315
  raise gr.Error("Please wait for the FHE execution to be completed.")
316
 
317
+ def decrypt_output(user_id, encrypted_output):
318
+ """Decrypt the output of the seizure detection."""
319
+ if user_id == "":
320
+ raise gr.Error("Please generate the private key first.")
321
 
322
+ if encrypted_output is None:
323
+ raise gr.Error("Please run the FHE computation first.")
324
 
325
+ # Retrieve the client API
326
+ client = get_client(user_id)
327
 
328
+ try:
329
+ # Deserialize and decrypt the output
330
+ decrypted_output = client.deserialize_decrypt_post_process(encrypted_output)
331
 
332
+ # Reshape the output to match the expected shape
333
+ decrypted_output = decrypted_output.reshape(1, 1, 32, 32, 3)
334
 
335
+ # Convert the output to a probability
336
+ probability = sigmoid(decrypted_output[0, 0, 0, 0, 0])
337
 
338
+ # Create a heatmap
339
+ heatmap = create_heatmap(decrypted_output[0, 0, :, :, 0])
 
340
 
341
+ return probability, heatmap
 
342
 
343
+ except Exception as e:
344
+ logger.error(f"Error in decrypt_output: {str(e)}")
345
+ raise gr.Error(f"Decryption failed: {str(e)}")
346
+
347
+ def sigmoid(x):
348
+ return 1 / (1 + np.exp(-x))
349
 
350
+ def create_heatmap(data):
351
+ # Normalize the data
352
+ data_normalized = (data - np.min(data)) / (np.max(data) - np.min(data))
353
+
354
+ # Create a heatmap
355
+ plt.figure(figsize=(6, 6))
356
+ sns.heatmap(data_normalized, cmap='YlOrRd', cbar=True)
357
+ plt.title('Activation Heatmap')
358
+ plt.axis('off')
359
+
360
+ # Save the plot to a BytesIO object
361
+ buf = io.BytesIO()
362
+ plt.savefig(buf, format='png')
363
+ buf.seek(0)
364
+
365
+ # Clear the current figure
366
+ plt.clf()
367
+
368
+ return buf
369
 
370
  def resize_img(img, width=256, height=256):
371
  """Resize the image."""
 
456
  decrypt_button = gr.Button("Decrypt the output")
457
 
458
  with gr.Row():
459
+ probability_output = gr.Number(label="Seizure Probability")
460
+ heatmap_output = gr.Image(label="Activation Heatmap")
461
+
462
+ # Update the decrypt button click event
463
+ decrypt_button.click(
464
+ decrypt_output,
465
+ inputs=[user_id, encrypted_output],
466
+ outputs=[probability_output, heatmap_output]
467
+ )
468
 
469
  # Button to generate the private key
470
  keygen_button.click(