# SPDX-FileCopyrightText: 2025 Idiap Research Institute
# SPDX-FileContributor: Hatef Otroshi <[email protected]>
# SPDX-License-Identifier: MIT
"""HyperFace demo"""
from __future__ import annotations
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from huggingface_hub import hf_hub_download
from title import title_css, title_with_logo
from face_alignment import align
from PIL import Image
import net
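
# Model registry: each dropdown choice maps to a Hugging Face Hub repository and the
# checkpoint filename inside it. Checkpoints are downloaded lazily by get_face_rec_model().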
model_configs = {
    "HyperFace-10k-LDM": {
        "repo": "idiap/HyperFace-10k-LDM",
        "filename": "HyperFace_10k_LDM.ckpt",
    },
    "HyperFace-10k-StyleGAN": {
        "repo": "idiap/HyperFace-10k-StyleGAN",
        "filename": "HyperFace_10k_StyleGAN.ckpt",
    },
    "HyperFace-50k-StyleGAN": {
        "repo": "idiap/HyperFace-50k-StyleGAN",
        "filename": "HyperFace_50k_StyleGAN.ckpt",
    },
}
# ───────────────────────────────
# Data & models
# ───────────────────────────────
DATA_DIR = Path("data")
EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
PRELOADED = sorted(p for p in DATA_DIR.iterdir() if p.suffix.lower() in EXTS)
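# Dropdown choices shown in the UI; they must match the keys of model_configs above.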
HYPERFACE_MODELS = [
    "HyperFace-10k-LDM",
    "HyperFace-10k-StyleGAN",
    "HyperFace-50k-StyleGAN",
]
# ───────────────────────────────
# Styling (orange palette)
# ───────────────────────────────
PRIMARY = "#F97316"
PRIMARY_DARK = "#C2410C"
ACCENT_LIGHT = "#FFEAD2"
BG_LIGHT = "#FFFBF7"
CARD_BG_DARK = "#473f38"
BG_DARK = "#332a22"
TEXT_DARK = "#0F172A"
TEXT_LIGHT = "#f8fafc"
CSS = f"""
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');
/* ─── palette ───────────────────────────────────────────── */
body, .gradio-container {{
font-family: 'Inter', sans-serif;
background: {BG_LIGHT};
color: {TEXT_DARK};
}}
a {{
color: {PRIMARY};
text-decoration: none;
font-weight: 600;
}}
a:hover {{ color: {PRIMARY_DARK}; }}
/* ─── headline ──────────────────────────────────────────── */
#titlebar {{
text-align: center;
margin-top: 2.4rem;
margin-bottom: .9rem;
}}
/* ─── card look ─────────────────────────────────────────── */
.gr-block,
.gr-box,
.gr-row,
#cite-wrapper {{
border: 1px solid #F8C89B;
border-radius: 10px;
background: #fff;
box-shadow: 0 3px 6px rgba(0, 0, 0, .05);
}}
.gr-gallery-item {{ background: #fff; }}
/* ─── controls / inputs ─────────────────────────────────── */
.gr-button-primary,
#copy-btn {{
background: linear-gradient(90deg, {PRIMARY} 0%, {PRIMARY_DARK} 100%);
border: none;
color: #fff;
border-radius: 6px;
font-weight: 600;
transition: transform .12s ease, box-shadow .12s ease;
}}
.gr-button-primary:hover,
#copy-btn:hover {{
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(249, 115, 22, .35);
}}
.gr-dropdown input {{
border: 1px solid {PRIMARY}99;
}}
.preview img,
.preview canvas {{ object-fit: contain !important; }}
/* ─── hero section ─────────────────────────────────────── */
#hero-wrapper {{ text-align: center; }}
#hero-badge {{
display: inline-block;
padding: .85rem 1.2rem;
border-radius: 8px;
background: {ACCENT_LIGHT};
border: 1px solid {PRIMARY}55;
font-size: .95rem;
font-weight: 600;
margin-bottom: .5rem;
}}
#hero-links {{
font-size: .95rem;
font-weight: 600;
margin-bottom: 1.6rem;
}}
#hero-links img {{
height: 22px;
vertical-align: middle;
margin-left: .55rem;
}}
/* ─── score area ───────────────────────────────────────── */
#score-area {{
text-align: center;
}}
.title-container {{
display: flex;
align-items: center;
gap: 12px;
justify-content: center;
margin-bottom: 10px;
text-align: center;
}}
.match-badge {{
display: inline-block;
padding: .35rem .9rem;
border-radius: 9999px;
font-weight: 600;
font-size: 1.25rem;
}}
/* ─── citation card ────────────────────────────────────── */
#cite-wrapper {{
position: relative;
padding: .9rem 1rem;
margin-top: 2rem;
}}
#cite-wrapper code {{
font-family: SFMono-Regular, Consolas, monospace;
font-size: .84rem;
white-space: pre-wrap;
color: {TEXT_DARK};
}}
#copy-btn {{
position: absolute;
top: .55rem;
right: .6rem;
padding: .18rem .7rem;
font-size: .72rem;
line-height: 1;
}}
/* ─── dark mode ────────────────────────────────────── */
.dark body,
.dark .gradio-container {{
background-color: {BG_DARK};
color: #e5e7eb;
}}
.dark .gr-block,
.dark .gr-box,
.dark .gr-row {{
background-color: {BG_DARK};
border: 1px solid #4b5563;
}}
.dark .gr-dropdown input {{
background-color: {BG_DARK};
color: #f1f5f9;
border: 1px solid {PRIMARY}aa;
}}
.dark #hero-badge {{
background: #334155;
border: 1px solid {PRIMARY}55;
color: #fefefe;
}}
.dark #cite-wrapper {{
background-color: {CARD_BG_DARK};
}}
.dark #bibtex {{
color: {TEXT_LIGHT} !important;
}}
.dark .card {{
background-color: {CARD_BG_DARK};
}}
/* ─── switch logo for light/dark theme ─────────────── */
.logo-dark {{ display: none; }}
.dark .logo-light {{ display: none; }}
.dark .logo-dark {{ display: inline; }}
"""
FULL_CSS = CSS + title_css(TEXT_DARK, PRIMARY, PRIMARY_DARK, TEXT_LIGHT)
# ───────────────────────────────
# Torch / transforms
# ───────────────────────────────
def to_input(pil_rgb_image):
    """Convert an aligned RGB PIL image to a normalised BGR tensor of shape (1, 3, H, W)."""
    np_img = np.array(pil_rgb_image)
    # RGB -> BGR, scale to [0, 1], then normalise to [-1, 1]
    bgr_img = ((np_img[:, :, ::-1] / 255.) - 0.5) / 0.5
    tensor = torch.tensor([bgr_img.transpose(2, 0, 1)]).float()
    return tensor

def get_face_rec_model(name: str) -> torch.nn.Module:
    """Download (on first use) and cache the requested HyperFace checkpoint as an eval-mode ir_50 model."""
    if name not in get_face_rec_model.cache:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model_path = hf_hub_download(
            repo_id=model_configs[name]["repo"],
            filename=model_configs[name]["filename"],
            local_dir="models",
        )
        model = net.build_model(model_name='ir_50')
        statedict = torch.load(model_path, map_location=device)['state_dict']
        # Checkpoint keys are prefixed with "model."; strip the prefix before loading.
        model_statedict = {key[6:]: val for key, val in statedict.items() if key.startswith('model.')}
        model.load_state_dict(model_statedict)
        model.eval()
        model.to(device)
        get_face_rec_model.cache[name] = model
    return get_face_rec_model.cache[name]


# Per-process cache of loaded models, keyed by variant name.
get_face_rec_model.cache = {}
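
# Sketch (not executed by the app): embedding a single aligned face outside the Gradio UI,
# assuming `aligned` is a 112x112 RGB PIL image returned by align.get_aligned_face:
#   model = get_face_rec_model("HyperFace-10k-LDM")
#   device = next(model.parameters()).device
#   with torch.no_grad():
#       embedding = model(to_input(aligned).to(device))[0]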
# ───────────────────────────────
# Helpers
# ───────────────────────────────
def _as_rgb(path: Path) -> np.ndarray:
    """Read an image from disk with OpenCV and return it as an RGB array."""
    return cv2.cvtColor(cv2.imread(str(path)), cv2.COLOR_BGR2RGB)


def badge(text: str, colour: str) -> str:
    return f'<div class="match-badge" style="background:{colour}22;color:{colour}">{text}</div>'
# ───────────────────────────────
# Face comparison
# ───────────────────────────────
def compare(img_left, img_right, variant):
    """Align both faces, embed them with the selected HyperFace model, and return a match badge."""
    if img_left is None and img_right is None:
        return None, None, badge("Please upload/select two face images", "#DC2626")
    if img_left is None:
        return None, None, badge("Please upload/select a face image for Image A (left)", "#DC2626")
    if img_right is None:
        return None, None, badge("Please upload/select a face image for Image B (right)", "#DC2626")

    img_left = Image.fromarray(img_left).convert('RGB')
    img_right = Image.fromarray(img_right).convert('RGB')
    crop_a, crop_b = align.get_aligned_face(None, img_left), align.get_aligned_face(None, img_right)
    if crop_a is None and crop_b is None:
        return None, None, badge("No face detected", "#DC2626")
    if crop_a is None:
        return None, None, badge("No face was detected in Image A (left)", "#DC2626")
    if crop_b is None:
        return None, None, badge("No face was detected in Image B (right)", "#DC2626")

    mdl = get_face_rec_model(variant)
    dev = next(mdl.parameters()).device
    with torch.no_grad():
        ea = mdl(to_input(crop_a).to(dev))[0]
        eb = mdl(to_input(crop_b).to(dev))[0]
    # Cosine similarity in [-1, 1], mapped to a percentage and clamped to [0, 100].
    pct = float(F.cosine_similarity(ea, eb).item() * 100)
    pct = max(0, min(100, pct))
    colour = "#15803D" if pct >= 70 else "#CA8A04" if pct >= 40 else "#DC2626"
    return crop_a, crop_b, badge(f"{pct:.2f}% match", colour)
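
# The badge maps cosine similarity (in [-1, 1]) to a clamped percentage: e.g. a similarity
# of 0.82 is shown as "82.00% match" in green (>= 70), while 0.25 is shown in red (< 40).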
# ───────────────────────────────
# Static HTML
# ───────────────────────────────
TITLE_HTML = title_with_logo(
"""<span class="brand">HyperFace:</span> Generating Synthetic Face Recognition Datasets by Exploring Face Embedding Hypersphere
"""
)
HERO_HTML = f"""
<div id="hero-wrapper">
<div id="hero-links">
<a href="https://www.idiap.ch/paper/hyperface/">Project</a>&nbsp;β€’&nbsp;
<a href="https://openreview.net/pdf?id=4YzVF9isgD">Paper</a>&nbsp;β€’&nbsp;
<a href="https://arxiv.org/abs/2411.08470v2">arXiv</a>&nbsp;β€’&nbsp;
<a href="https://gitlab.idiap.ch/biometric/code.iclr2025_hyperface">Code</a>&nbsp;β€’&nbsp;
<a href="https://huggingface.co/collections/Idiap/hyperface-682485119ccbd3ba5c42bde1">Models</a>&nbsp;β€’&nbsp;
<a href="https://zenodo.org/records/15087238">Dataset</a>
</div>
</div>
"""
CITATION_HTML = """
<div id="cite-wrapper">
<button id="copy-btn" onclick="
navigator.clipboard.writeText(document.getElementById('bibtex').innerText)
.then(()=>{this.textContent='✔︎';setTimeout(()=>this.textContent='Copy',1500);});
">Copy</button>
<code id="bibtex">
@inproceedings{shahreza2025hyperface,
  title={HyperFace: Generating Synthetic Face Recognition Datasets by Exploring Face Embedding Hypersphere},
  author={Hatef Otroshi Shahreza and S{\'e}bastien Marcel},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025}
}</code>
</div>
"""
# ───────────────────────────────
# Gradio UI
# ───────────────────────────────
with gr.Blocks(css=FULL_CSS, title="HyperFace Demo") as demo:
    gr.HTML(TITLE_HTML, elem_id="titlebar")
    gr.HTML(HERO_HTML)

    with gr.Row():
        gal_a = gr.Gallery(
            PRELOADED,
            columns=[5],
            height=120,
            label="Image A",
            object_fit="contain",
            elem_classes="card",
        )
        gal_b = gr.Gallery(
            PRELOADED,
            columns=[5],
            height=120,
            label="Image B",
            object_fit="contain",
            elem_classes="card",
        )

    with gr.Row():
        img_a = gr.Image(
            type="numpy",
            height=300,
            label="Image A (click or drag-drop)",
            interactive=True,
            elem_classes="preview card",
        )
        img_b = gr.Image(
            type="numpy",
            height=300,
            label="Image B (click or drag-drop)",
            interactive=True,
            elem_classes="preview card",
        )

    def _fill(evt: gr.SelectData):
        """Copy the selected gallery thumbnail into the corresponding upload box."""
        return _as_rgb(PRELOADED[evt.index]) if evt.index is not None else None

    gal_a.select(_fill, outputs=img_a)
    gal_b.select(_fill, outputs=img_b)

    variant_dd = gr.Dropdown(
        HYPERFACE_MODELS, value="HyperFace-10k-LDM", label="Model variant", elem_classes="card"
    )
    btn = gr.Button("Compare", variant="primary")

    with gr.Row():
        out_a = gr.Image(label="Aligned A (112×112)", elem_classes="card")
        out_b = gr.Image(label="Aligned B (112×112)", elem_classes="card")
    score_html = gr.HTML(elem_id="score-area")

    btn.click(compare, [img_a, img_b, variant_dd], [out_a, out_b, score_html])
    gr.HTML(CITATION_HTML)
# ───────────────────────────────
if __name__ == "__main__":
    demo.launch(share=True)