# Hugging Face Space application (Space status at capture time: Running).
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import random | |
from transformers import AutoImageProcessor, AutoModel | |
import torch | |
import timm | |
import torch.nn.functional as F | |
from torchvision import transforms | |
import time | |
import subprocess | |
import os | |
from facenet_pytorch import MTCNN | |
# Face detector built once at import time and shared by crop_face_to_112x112,
# which only uses the first box returned by `mtcnn.detect`.
# NOTE(review): keep_all=False presumably restricts the detector to a single
# face — confirm against the facenet_pytorch MTCNN documentation.
mtcnn = MTCNN(keep_all=False)
def crop_face_to_112x112(image, count=None):
    """Crop the face in *image* to 112x112 and save it to disk.

    Images that are already 112x112 are used as-is. Otherwise the first box
    returned by the module-level MTCNN detector is cropped (clamped to the
    image bounds) and resized with bilinear interpolation.

    Args:
        image: PIL image from a Gradio upload component, or None when the
            component is cleared.
        count: Enrollment upload counter (Gradio state). When given, the crop
            is saved as a new reference image named after the counter, and the
            incremented counter is returned. When None, the crop overwrites
            the single probe image.

    Returns:
        ``(processed, img_path, count)`` in enrollment mode, otherwise
        ``(processed, img_path)``. Both shapes match the Gradio outputs wired
        to this function. ``processed`` and ``img_path`` are None when *image*
        is None.

    Raises:
        ValueError: if no face is detected.
    """
    enroll = count is not None
    if image is None:
        # Keep the tuple length equal to the number of wired Gradio outputs.
        return (None, None, count) if enroll else (None, None)

    if image.size == (112, 112):
        processed = image
    else:
        boxes, _ = mtcnn.detect(image)
        if boxes is None:
            raise ValueError("No face detected.")
        width, height = image.size
        x1, y1, x2, y2 = map(int, boxes[0])
        # Detector boxes may extend past the frame. Clamp them so the crop is
        # not padded with empty borders.
        box = (max(0, x1), max(0, y1), min(width, x2), min(height, y2))
        processed = image.crop(box).resize((112, 112), Image.BILINEAR)

    if enroll:
        subject_dir = './uploaded_images/uploadedSubj'
        os.makedirs(subject_dir, exist_ok=True)
        img_path = f'{subject_dir}/uploadedSubj-{count}-img.jpg'
        processed.save(img_path)
        # Advance the counter so the next upload gets a new file name instead
        # of overwriting this one.
        return processed, img_path, count + 1

    subject_dir = './uploaded_probe_images/uploadedSubj'
    os.makedirs(subject_dir, exist_ok=True)
    img_path = f'{subject_dir}/img.jpg'
    processed.save(img_path)
    return processed, img_path
def numSTR(num):
    """Return the text form of *num*, e.g. to copy a number into a hidden Textbox."""
    text = f"{num!s}"
    return text
# Security levels (bits) offered in the UI, as strings because the chosen value
# is forwarded verbatim on the search binary's command line.
# NOTE(review): the subject-offset help text mentions 192 bits while this list
# offers "196" — confirm which value the binary expects.
SECURITYLEVELS = [str(bits) for bits in (128, 196, 256)]

# Face-recognition backbones on the Hugging Face Hub, loaded through timm's
# "hf_hub:" prefix. The first entry is extract_emb's default model.
FRMODELS = [
    f"gaunernst/vit_tiny_patch8_112.{head}_ms1mv3"
    for head in ("arcface", "cosface")
]
# NOTE: display_enrolled_image is defined again further down this module, and
# the later definition is the one bound at import time. This copy is shadowed;
# keep the two in sync or delete one.
def display_enrolled_image():
    """Return the gallery image paths of every currently enrolled subject.

    Subjects are the first whitespace-separated token of each non-empty line
    of ``./Server/subjOffsetMapping.txt``. Example reference images whose path
    contains an enrolled subject id are returned in ``example_images`` order.
    If ``uploadedSubj`` is enrolled, the files saved in
    ``./uploaded_images/uploadedSubj/`` are appended in name order.

    Returns ``[]`` when the mapping file does not exist yet (nothing enrolled).
    """
    file_path = './Server/subjOffsetMapping.txt'
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            subjects = {line.split()[0] for line in file if line.strip()}
    except FileNotFoundError:
        return []
    # Walk the fixed example list rather than the set so the gallery order is
    # stable from one click to the next.
    enrDB = [img for img in example_images if any(subject in img for subject in subjects)]
    if 'uploadedSubj' in subjects:
        dir_path = './uploaded_images/uploadedSubj/'
        if os.path.isdir(dir_path):
            enrDB += [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path))]
    return enrDB
def runBinFile(*args):
    """Run the compiled FHE search binary and shape its result for Gradio.

    Every positional argument is forwarded, converted to ``str``, as the
    process argv. ``args[0]`` is the binary path. Marker arguments select the
    return shape so that it always matches the outputs wired to the calling
    event:

    * ``"print"`` in args: the binary's stdout (single value), or an error
      message string.
    * ``"styledPrint"`` in args: ``(success, timing_or_error_html,
      decision_html)``. The decision is rendered with ``styled_output``.
    * otherwise: ``(success, timing_or_error_html)``.

    When ``"genkeys"`` is among the arguments, the binary is first invoked as
    ``<binary> <level> delete``.

    Returns:
        See above. Failures (missing input, missing binary, non-zero exit
        code, launch error) use the same shape with ``success=False``.
    """
    print(args)
    wants_print = 'print' in args
    wants_styled = 'styledPrint' in args

    def _failure(message, stdout=""):
        # Same tuple length as the success path, so Gradio never receives
        # fewer values than it has outputs.
        if wants_print:
            return message
        if wants_styled:
            return False, message, styled_output(stdout)
        return False, message

    if any(arg is None for arg in args):
        # e.g. no security level chosen yet, or no embedding generated yet.
        return _failure("Error: a required input is missing.")
    binary_path = args[0]
    if not os.path.isfile(binary_path):
        return _failure("Error: Compiled binary not found.")
    try:
        if 'genkeys' in args:
            # Presumably clears previously generated keys before regenerating.
            runBinFile(args[0], args[1], 'delete')
        os.chmod(binary_path, 0o755)
        start = time.perf_counter()
        result = subprocess.run(
            [str(arg) for arg in args],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        duration = time.perf_counter() - start
    except Exception as e:  # UI boundary: report the error instead of crashing the handler.
        return _failure(f"Execution failed: {e}")

    if wants_print:
        return result.stdout
    if result.returncode != 0:
        return _failure(
            f"Error: binary exited with code {result.returncode}.", result.stdout
        )
    timing = f"<b>⏱️ Processing Time:</b> {duration:.2f} s"
    if wants_styled:
        return True, timing, styled_output(result.stdout)
    return True, timing
# VGGFace2 identities shown in both example galleries, in display order.
_EXAMPLE_SUBJECTS = ("n000001", "n000149", "n000082", "n000148", "n000129", "n000394")

# Reference (enrollment) gallery: one photo per subject above.
example_images = [
    f"./VGGFace2/{subject}/{stem}.jpg"
    for subject, stem in zip(
        _EXAMPLE_SUBJECTS,
        ("0002_01", "0002_01", "0001_02", "0014_01", "0001_01", "0007_01"),
    )
]

# Probe (search) gallery: another photo file of each of the same subjects.
example_images_auth = [
    f"./VGGFace2/{subject}/{stem}.jpg"
    for subject, stem in zip(
        _EXAMPLE_SUBJECTS,
        ("0013_01", "0019_01", "0003_03", "0043_01", "0006_01", "0018_01"),
    )
]
def display_image(image):
    """Pass *image* through unchanged so it can be mirrored into another component."""
    mirrored = image
    return mirrored
def display_enrolled_image():
    """Return the gallery image paths of every currently enrolled subject.

    Subjects are the first whitespace-separated token of each non-empty line
    of ``./Server/subjOffsetMapping.txt``. Example reference images whose path
    contains an enrolled subject id are returned in ``example_images`` order.
    If ``uploadedSubj`` is enrolled, the files saved in
    ``./uploaded_images/uploadedSubj/`` are appended in name order.

    Returns ``[]`` when the mapping file does not exist yet (nothing enrolled).
    """
    file_path = './Server/subjOffsetMapping.txt'
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            subjects = {line.split()[0] for line in file if line.strip()}
    except FileNotFoundError:
        return []
    # Walk the fixed example list rather than the set so the gallery order is
    # stable from one click to the next.
    enrDB = [img for img in example_images if any(subject in img for subject in subjects)]
    if 'uploadedSubj' in subjects:
        dir_path = './uploaded_images/uploadedSubj/'
        if os.path.isdir(dir_path):
            enrDB += [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path))]
    return enrDB
def display_usedOffsets():
    """Return a Markdown prompt listing the subject offsets already taken.

    Offsets are read from the second whitespace-separated column of
    ``./Server/subjOffsetMapping.txt``. Lines with fewer than two columns are
    skipped, and a missing file (nothing enrolled yet) yields an empty list.
    Offsets are listed in ascending numeric order; non-numeric entries, if
    any, come last in text order.
    """
    file_path = './Server/subjOffsetMapping.txt'
    offsets = set()
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                parts = line.split()
                if len(parts) >= 2:
                    offsets.add(parts[1])
    except FileNotFoundError:
        pass  # Nothing enrolled yet, so no offset is in use.

    def _sort_key(offset):
        # Numeric offsets first, ordered by value; anything else after, by text.
        return (0, int(offset), '') if offset.isdecimal() else (1, 0, offset)

    formatted = ", ".join(sorted(offsets, key=_sort_key))
    return f"**Choose an available offset. Unavailable offsets :** {formatted}"
def load_rec_image():
    """Return the path of the static reconstructed-face image shown by the demo."""
    reconstructed = 'static/reconstructed.png'
    return reconstructed
# Face-recognition models already built in this process, keyed by Hub id, so a
# backbone is created once instead of on every "Generate embedding" click.
_FR_MODEL_CACHE = {}


def _get_fr_model(modelName):
    """Return the eval-mode timm model for *modelName*, building it on first use."""
    model = _FR_MODEL_CACHE.get(modelName)
    if model is None:
        model = timm.create_model(f"hf_hub:{modelName}", pretrained=True).eval()
        _FR_MODEL_CACHE[modelName] = model
    return model


def extract_emb(image, modelName=FRMODELS[0], mode=None, imgPath=None):
    """Compute the L2-normalised face embedding of *image*.

    Args:
        image: Face image accepted by ``transforms.ToTensor`` (a PIL image or
            an H x W x C uint8 array).
        modelName: Hub id of the timm backbone. ``None`` (dropdown left empty)
            falls back to ``FRMODELS[0]``.
        mode: When given (e.g. ``"enroll"`` or ``"auth"``), the embedding is
            also written to ``./embeddings/<subject>/<mode>-emb.txt``.
        imgPath: Path of the source image. Its parent directory name is used
            as ``<subject>``. Required when *mode* is given.

    Returns:
        The 1-D embedding as a numpy array, or ``(embedding, emb_path)`` when
        *mode* is given.

    Raises:
        ValueError: if *image* is None, or *mode* is given without *imgPath*.
    """
    if image is None:
        raise ValueError("No image provided: select or upload an image first.")
    if modelName is None:
        modelName = FRMODELS[0]
    # Deterministic preprocessing: no random augmentation at inference time.
    # A random flip here would give the same face a different embedding on
    # each call, making enrollment and probe embeddings inconsistent.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        embs = F.normalize(_get_fr_model(modelName)(batch), dim=1)
    embs = embs.detach().numpy().squeeze(0)
    if mode is None:
        return embs
    if imgPath is None:
        raise ValueError("No image path available: select or upload an image first.")
    subject = imgPath.split('/')[-2]
    emb_dir = f'./embeddings/{subject}'
    os.makedirs(emb_dir, exist_ok=True)
    emb_path = f'{emb_dir}/{mode}-emb.txt'
    # Stored as a single comma-separated row with 6 decimals.
    np.savetxt(emb_path, embs.reshape(1, -1), fmt="%.6f", delimiter=',')
    return embs, emb_path
def get_selected_image(evt: gr.SelectData):
    """Map a click on the reference gallery to its example image path.

    The same path is returned twice so it can fill two outputs: the image
    component and the path state.
    """
    chosen = example_images[evt.index]
    return chosen, chosen
def get_selected_image_auth(evt: gr.SelectData):
    """Map a click on the probe gallery to its example image path.

    The same path is returned twice so it can fill two outputs: the image
    component and the path state.
    """
    chosen = example_images_auth[evt.index]
    return chosen, chosen
def styled_output(result):
    """Render the binary's textual search decision as a coloured HTML badge.

    ``"found"`` and ``"not found"`` (compared case- and whitespace-
    insensitively) map to green and red badges. Anything else, including an
    empty string or None, is shown as a red "Error" badge.
    """
    decision = (result or "").strip().lower()
    if decision == "found":
        return "<span style='color: green; font-weight: bold;'>✔️ Found</span>"
    if decision == "not found":
        return "<span style='color: red; font-weight: bold;'>❌ Not Found</span>"
    return "<span style='color: red; font-weight: bold;'>Error</span>"
# Gradio UI: header, setup phase (FHE key generation), Phase 1 (enrollment) and
# Phase 2 (search). Every FHE step shells out to ./bin/search.bin through
# runBinFile. The gr.State(...) inputs are constant command-line arguments
# (binary path, sub-command, and marker flags such as "print"/"styledPrint").
# key_button / key_status / time_output / output_text / btn are rebound for
# each step; every .click() wires the objects bound at that point.
with gr.Blocks() as demo:
    gr.HTML(
        """
<h1 align="center">Suraksh.AI</h1>
<p align="center">
<a href="https://suraksh-ai.vercel.app/"> https://suraksh-ai.vercel.app/</a>
</p>
"""
    )
    gr.Markdown("# Biometric Search (1:N search) using Fully Homomorphic Encryption (FHE)")
    # The four demo scenarios and their expected outcomes.
    gr.HTML(
        """
<p>This demo shows <strong>Suraksh.AI's</strong> biometric search solution under <strong>FHE</strong>.</p>
<ul>
<li><strong>Scenario 1</strong>: Searching an enrolled subject. For this scenario, the reference and probe should be from the same subject. Expected outcome: <span style='color: green; font-weight: bold;'>βοΈ Found</span></li>
<li><strong>Scenario 2</strong>: Searching an enrolled subject with high recognition threshold. For this scenario, the reference and probe should be from the same subject and the recognition threshold set to a high value. Expected outcome: <span style='color: red; font-weight: bold;'>β Not Found</span></li>
<li><strong>Scenario 3</strong>: Searching a non-enrolled subject. For this scenario, choose a probe not enrolled. Expected outcome: <span style='color: red; font-weight: bold;'>β Not Found</span></li>
<li><strong>Scenario 4</strong>: Searching a non-enrolled subject with low recognition threshold. For this scenario, choose a probe not enrolled and lower the high recognition threshold. Expected outcome: <span style='color: green; font-weight: bold;'>βοΈ Found</span></li>
</ul>
"""
    )
    # ---- Setup phase: FHE key generation ----
    with gr.Row():
        gr.Markdown("### Setup Phase: π Generate the FHE public and secret keys.")
    with gr.Row():
        with gr.Column():
            # NOTE(review): SECURITYLEVELS offers "196" while the subject-offset
            # help text below mentions 192 bits — confirm what the binary expects.
            securityLevel = gr.Dropdown(
                choices=SECURITYLEVELS,
                label="Choose a security level"
            )
        with gr.Column():
            key_button = gr.Button("Generate the FHE public and secret keys")
            key_status = gr.Checkbox(label="FHE Public and Secret keys generated.", value=False)
            time_output = gr.HTML()
            # argv: search.bin <level> genkeys (runBinFile first runs "<level> delete").
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("genkeys")], outputs=[key_status,time_output])
    # ---- Phase 1: enrollment ----
    with gr.Row():
        gr.Markdown("## Phase 1: Enrollment")
    with gr.Row():
        gr.Markdown("### Step 1: Upload or select a reference facial image for enrollment.")
    with gr.Row():
        # Path of the chosen reference image (gallery example or saved upload).
        selectedImagePath = gr.State()
        # Upload counter passed into, and returned by, crop_face_to_112x112.
        count = gr.State(0)
        # Hidden holder for the 112x112 reference image used by later steps.
        image_input_enroll = gr.Image(type="pil", visible=False)
        with gr.Column():
            image_upload_enroll = gr.Image(label="Upload a reference facial image.", type="pil", sources="upload")
            # Enrollment mode (counter given): returns (cropped image, saved path, counter).
            image_upload_enroll.change(fn=crop_face_to_112x112, inputs=[image_upload_enroll, count], outputs=[image_input_enroll, selectedImagePath, count])
        with gr.Column():
            refDB_gallery = gr.Gallery(value=example_images, columns=3)
            # A gallery click sets both the hidden image and the path state to the example path.
            refDB_gallery.select(fn=get_selected_image, inputs=None, outputs=[image_input_enroll, selectedImagePath])
        with gr.Column():
            selectedImage = gr.Image(type="pil", label="Reference facial image", interactive=False)
            # Mirror the hidden reference image into the visible preview.
            image_input_enroll.change(fn=lambda img: img, inputs=image_input_enroll, outputs=selectedImage)
    with gr.Row():
        gr.Markdown("### Step 2: Generate reference embedding.")
    with gr.Row():
        with gr.Column():
            modelName = gr.Dropdown(
                choices=FRMODELS,
                label="Choose a face recognition model"
            )
        with gr.Column():
            key_button = gr.Button("Generate embedding")
            enroll_emb_text = gr.JSON(label="Reference embedding")
            # extract_emb saves the embedding as ./embeddings/<subject>/enroll-emb.txt.
            mode = gr.State("enroll")
            ref_emb_path = gr.State()
            key_button.click(fn=extract_emb, inputs=[image_input_enroll, modelName, mode, selectedImagePath], outputs=[enroll_emb_text, ref_emb_path])
    with gr.Row():
        gr.Markdown("""Facial embeddings are **INVERTIBLE** and lead to the **RECONSTRUCTION** of their raw facial images.""")
    with gr.Row():
        gr.Markdown("### Example:")
    with gr.Row():
        original_image = gr.Image(value="static/original.jpg", label="Original", sources="upload")
        key_button = gr.Button("Generate embedding")
        output_text = gr.JSON(label="Target embedding")
        # No mode given: extract_emb returns only the embedding and writes nothing.
        key_button.click(fn=extract_emb, inputs=[original_image, modelName], outputs=output_text)
        btn = gr.Button("Reconstruct facial image")
        Reconstructed_image = gr.Image(label="Reconstructed")
        # load_rec_image returns a static file path; nothing is reconstructed on click.
        btn.click(fn=load_rec_image, outputs=Reconstructed_image)
    with gr.Row():
        gr.Markdown("""Facial embeddings protection is a must! At **Suraksh.AI**, we protect facial embeddings using FHE.""")
    with gr.Row():
        gr.Markdown("### Step 3: π Encrypt reference embedding using FHE.")
    with gr.Row():
        gr.Markdown("### Set subject offset.")
    with gr.Row():
        with gr.Column():
            # Filled only after the offset changes (no value on page load).
            usedOffsets = gr.Markdown()
            subjOffset = gr.Number(value=1, label="Subject offset", info='Between 1 and 8192 for security levels of 128 and 192 bits and between 1 and 16384 for the security level 256 bits.')
            # String copy of the offset, forwarded on the binary's command line.
            subjOffset_txt = gr.Textbox(visible=False, value = '1')
            subjOffset.change(fn=display_usedOffsets, inputs=None, outputs = usedOffsets)
            # NOTE(review): if gr.Number yields a float, numSTR produces e.g. "1.0"
            # while the initial hidden value is '1' — confirm the binary accepts both.
            subjOffset.change(fn=numSTR, inputs=subjOffset, outputs=subjOffset_txt)
    with gr.Row():
        with gr.Column():
            key_button = gr.Button("Encrypt")
            key_status = gr.Checkbox(label="Reference embedding encrypted.", value=False)
            time_output = gr.HTML()
            # argv: search.bin <level> encRef <ref_emb_path> <offset>
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("encRef"), ref_emb_path, subjOffset_txt], outputs=[key_status, time_output])
        with gr.Column():
            key_button = gr.Button("Display")
            output_text = gr.Text(label="Encrypted embedding", lines=3, interactive=False)
            # "print" marker: runBinFile returns the binary's stdout.
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("printVectorCipher"), gr.State("encRef"), gr.State("print")], outputs=output_text)
    with gr.Row():
        gr.Markdown("### Step 4: π Add encrypt reference to the encrypted reference DB.")
    with gr.Row():
        with gr.Column():
            add_button = gr.Button("Add")
            add_status = gr.Checkbox(label="Encrypted reference DB updated with the encrypted reference.", value=False)
            time_output = gr.HTML()
            add_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("addRef")], outputs=[add_status, time_output])
        with gr.Column():
            key_button = gr.Button("Display")
            output_text = gr.Text(label="Encrypted reference DB", lines=3, interactive=False)
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("printVectorCipher"), gr.State("encRefDB"), gr.State("print")], outputs=output_text)
    # ---- Phase 2: search ----
    with gr.Row():
        gr.Markdown("## Phase 2: Search")
    with gr.Row():
        gr.Markdown("### Step 1: Upload or select a probe facial image for search.")
    with gr.Row():
        selectedImagePath_auth = gr.State()
        # Hidden holder for the 112x112 probe image used by later steps.
        image_input_auth = gr.Image(type="pil", visible=False)
        with gr.Column():
            image_upload_auth = gr.Image(label="Upload a facial image.", type="pil", sources="upload")
            # Probe mode (no counter): returns (cropped image, saved path).
            image_upload_auth.change(fn=crop_face_to_112x112, inputs=image_upload_auth, outputs=[image_input_auth,selectedImagePath_auth])
        with gr.Column():
            prob_gallery = gr.Gallery(value=example_images_auth, columns=3)
            prob_gallery.select(fn=get_selected_image_auth, inputs=None, outputs=[image_input_auth,selectedImagePath_auth])
        with gr.Column():
            selectedImage = gr.Image(type="pil", label="Probe facial image", interactive=False)
            image_input_auth.change(fn=lambda img: img, inputs=image_input_auth, outputs=selectedImage)
    with gr.Row():
        gr.Markdown("### Step 2: Generate probe facial embedding.")
    with gr.Row():
        with gr.Column():
            key_button = gr.Button("Generate embedding")
            prob_emb_text = gr.JSON(label="Probe embedding")
            # extract_emb saves the embedding as ./embeddings/<subject>/auth-emb.txt.
            mode = gr.State("auth")
            prob_emb_path = gr.State()
            key_button.click(fn=extract_emb, inputs=[image_input_auth, modelName, mode, selectedImagePath_auth], outputs=[prob_emb_text,prob_emb_path])
    with gr.Row():
        gr.Markdown("### Step 3: π Generate protected probe embedding.")
    with gr.Row():
        with gr.Column():
            key_button = gr.Button("Protect")
            key_status = gr.Checkbox(label="Probe embedding protected.", value=False)
            time_output = gr.HTML()
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("encProbe"), prob_emb_path], outputs=[key_status,time_output])
        with gr.Column():
            key_button = gr.Button("Display")
            output_text = gr.Text(label="Protected embedding", lines=3, interactive=False)
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("printProbe"), gr.State("print")], outputs=output_text)
    with gr.Row():
        gr.Markdown("### Step 4: π Run encrypted biometric search.")
    with gr.Row():
        with gr.Column():
            key_button = gr.Button("Biometric search under FHE")
            key_status = gr.Checkbox(label="Search decision encrypted.", value=False)
            time_output = gr.HTML()
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("search")], outputs=[key_status,time_output])
        with gr.Column():
            key_button = gr.Button("Display")
            output_text = gr.Text(label="Encrypted scores", lines=3, interactive=False)
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("printEncScores"), gr.State("print")], outputs=output_text)
    with gr.Row():
        gr.Markdown("### Step 5: π Decrypt scores and make a decision.")
    with gr.Row():
        gr.Markdown("### Set the recognition threshold.")
    with gr.Row():
        slider_threshold = gr.Slider(-512*5, 512*5, step=1, value=133, label="Decision threshold", info="The higher the stricter.", interactive=True)
        # Hidden copy of the threshold, forwarded on the binary's command line.
        number_threshold = gr.Textbox(visible=False, value = '133')
        slider_threshold.change(fn=lambda x: x, inputs=slider_threshold, outputs=number_threshold)
    with gr.Row():
        # Created here; its click handler is wired below once its outputs exist.
        key_button = gr.Button("Decrypt and Decide")
    with gr.Row():
        with gr.Column(scale=1):
            key_status = gr.Checkbox(label="Recognition decision encrypted.", value=False)
            time_output = gr.HTML()
        with gr.Column(scale=1):
            final_output = gr.HTML()
            # "styledPrint" marker: runBinFile returns (status, timing, styled decision).
            key_button.click(fn=runBinFile, inputs=[gr.State("./bin/search.bin"), securityLevel, gr.State("decDecisionClear"), number_threshold, gr.State("styledPrint")], outputs=[key_status, time_output, final_output])
        with gr.Column(scale=1):
            image_output_auth = gr.Image(label="Probe", sources="upload")
            image_input_auth.change(fn=display_image, inputs=image_input_auth, outputs=image_output_auth)
            btn = gr.Button("Display Reference DB", scale=0)
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Reference DB",columns=3)
            # Fill the gallery with the images of every enrolled subject.
            btn.click(display_enrolled_image, None, output_gallery)
# Start the Gradio server only when run as a script (e.g. `python app.py`),
# so importing this module does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch()