Spaces:
Running
Running
import os | |
import zipfile | |
import torch | |
import clip | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
import openai | |
from tqdm import tqdm | |
from glob import glob | |
import psycopg2 | |
from psycopg2.extras import execute_values | |
import json | |
import time | |
# ─────────────────────────────────────────────
# STEP 1: UNZIP TO CORRECT STRUCTURE
# ─────────────────────────────────────────────
# Archive shipped with the app and the folder it is expanded into.
zip_name = "lfw-faces.zip"
unzip_dir = "lfw-faces"

# Extract only when the target folder is absent, so restarts skip the work.
if not os.path.exists(unzip_dir):
    print("π Unzipping...")
    with zipfile.ZipFile(zip_name) as zip_ref:  # default mode is "r"
        zip_ref.extractall(path=unzip_dir)
    print("β Unzipped into:", unzip_dir)

# The per-person image folders sit one level down inside the extracted tree.
img_root = os.path.join(unzip_dir, "lfw-deepfunneled")
# ─────────────────────────────────────────────
# STEP 2: DATABASE SETUP
# ─────────────────────────────────────────────
def setup_database():
    """Prepare PostgreSQL for face matching and return a ready connection.

    Creates the ``face_matcher`` database when it is missing, enables the
    pgvector extension, and creates the ``faces`` table and its embedding
    index if they do not exist yet.

    Returns:
        An open psycopg2 connection (autocommit enabled) to the
        ``face_matcher`` database, or ``None`` if any step failed. The
        failure reason is printed rather than raised.
    """
    # Local import so this function carries its own dependency.
    from psycopg2 import sql

    # Database configuration
    DB_CONFIG = {
        "dbname": "face_matcher",
        "user": "postgres",
        "password": "postgres",  # Change this to your actual password
        "host": "localhost",
        "port": "5432",
    }

    conn = None
    try:
        # Bootstrap through the maintenance database so the target database
        # can be created. The port is passed too, so both connections reach
        # the same server.
        conn = psycopg2.connect(
            dbname="postgres",
            user=DB_CONFIG["user"],
            password=DB_CONFIG["password"],
            host=DB_CONFIG["host"],
            port=DB_CONFIG["port"],
        )
        # CREATE DATABASE cannot run inside a transaction block.
        conn.autocommit = True
        with conn.cursor() as cur:
            cur.execute(
                "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
                (DB_CONFIG["dbname"],),
            )
            if cur.fetchone() is None:
                # Identifiers cannot be bound as parameters; quote explicitly.
                cur.execute(
                    sql.SQL("CREATE DATABASE {}").format(
                        sql.Identifier(DB_CONFIG["dbname"])
                    )
                )
                print(f"Database {DB_CONFIG['dbname']} created.")
        conn.close()
        conn = None

        # Connect to the face_matcher database
        conn = psycopg2.connect(**DB_CONFIG)
        conn.autocommit = True
        with conn.cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            cur.execute("""
                CREATE TABLE IF NOT EXISTS faces (
                    id SERIAL PRIMARY KEY,
                    path TEXT UNIQUE NOT NULL,
                    name TEXT NOT NULL,
                    embedding vector(512),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # NOTE(review): an ivfflat index is only used by queries whose
            # distance operator matches its operator class (vector_ip_ops
            # pairs with `<#>`), and pgvector recommends building it after
            # the table holds data. Confirm against the search query.
            cur.execute(
                "CREATE INDEX IF NOT EXISTS faces_embedding_idx "
                "ON faces USING ivfflat (embedding vector_ip_ops)"
            )
        print("β Database setup complete.")
        return conn
    except Exception as e:
        print(f"β Database setup failed: {e}")
        # Do not leak a half-initialised connection on failure.
        if conn is not None:
            conn.close()
        return None
# ─────────────────────────────────────────────
# STEP 3: LOAD CLIP MODEL
# ─────────────────────────────────────────────
# Run on the GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load yields the network together with its image preprocessing callable.
model, preprocess = clip.load("ViT-B/32", device=device)
print(f"β CLIP model loaded on {device}")
# ─────────────────────────────────────────────
# STEP 4: EMBEDDING FUNCTIONS
# ─────────────────────────────────────────────
def embed_image(image_path):
    """Return the normalised CLIP embedding for the image at ``image_path``.

    The image is opened, converted to RGB, run through the module-level
    ``preprocess`` and ``model``, flattened to a 1-D NumPy array and divided
    by its own norm. Any failure is printed and reported as ``None``.
    """
    try:
        rgb = Image.open(image_path).convert("RGB")
        # Add the batch dimension the encoder expects and move to the device.
        batch = preprocess(rgb).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model.encode_image(batch)
            emb = features.cpu().numpy().flatten()
        emb /= np.linalg.norm(emb)
        return emb
    except Exception as e:
        print(f"β οΈ Error embedding {image_path}: {e}")
        return None
def populate_database(conn, limit=500):
    """Embed dataset images and store those not yet present in ``faces``.

    Takes the first ``limit`` ``.jpg`` files (sorted) found one folder below
    ``img_root``, skips paths already stored, embeds the rest in batches of
    50 and inserts each batch. Raises ``RuntimeError`` when no image files
    are selected.
    """
    pattern = os.path.join(img_root, "*", "*.jpg")
    candidates = sorted(glob(pattern))[:limit]
    if not candidates:
        raise RuntimeError("β No image files found in unzipped structure!")

    cur = conn.cursor()

    # Paths already stored are skipped so reruns only embed new files.
    cur.execute("SELECT path FROM faces")
    known = {row[0] for row in cur.fetchall()}
    pending = [p for p in candidates if p not in known]
    if len(pending) == 0:
        print("β All images are already in the database.")
        return

    print(f"π§ Generating CLIP embeddings for {len(pending)} new images...")

    # Work in fixed-size chunks so only one batch of rows is held at a time.
    batch_size = 50
    for start in range(0, len(pending), batch_size):
        chunk = pending[start:start + batch_size]
        rows = []
        for image_path in tqdm(chunk, desc=f"Embedding batch {start//batch_size + 1}"):
            try:
                vector = embed_image(image_path)
                if vector is None:
                    continue
                # Display name = file name without extension, underscores as spaces.
                stem, _ext = os.path.splitext(os.path.basename(image_path))
                rows.append((image_path, stem.replace("_", " "), vector.tolist()))
            except Exception as e:
                print(f"β οΈ Error with {image_path}: {e}")

        if rows:
            execute_values(
                cur,
                "INSERT INTO faces (path, name, embedding) VALUES %s ON CONFLICT (path) DO NOTHING",
                rows,
                template="(%s, %s, %s::vector)",
            )
            conn.commit()

    # Report the table size after the import.
    cur.execute("SELECT COUNT(*) FROM faces")
    (total_faces,) = cur.fetchone()
    print(f"β Database now contains {total_faces} faces.")
# ─────────────────────────────────────────────
# STEP 5: LOAD OPENAI API KEY
# ─────────────────────────────────────────────
# Take the key from the environment; the value is None when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY")
# ─────────────────────────────────────────────
# STEP 6: FACE MATCHING FUNCTION
# ─────────────────────────────────────────────
def scan_face(user_image, conn):
    """Find the stored faces most similar to an uploaded image.

    Args:
        user_image: PIL image to look up, or ``None`` when nothing was
            uploaded.
        conn: Open psycopg2 connection to the database holding ``faces``.

    Returns:
        A 4-tuple ``(gallery, captions, risk, explanation)``:
        ``gallery`` is a list of PIL images for the matches that could be
        opened, ``captions`` is a newline-joined string of names and
        similarity scores, ``risk`` is a string ``"<0-100>/100"`` derived
        from the mean similarity, and ``explanation`` is the GPT-4 text or
        an error/status message. Failures are reported through the
        ``explanation`` slot instead of being raised.
    """
    if user_image is None:
        return [], "", "", "Please upload a face image."

    try:
        rgb_image = user_image.convert("RGB")
        tensor = preprocess(rgb_image).unsqueeze(0).to(device)
        with torch.no_grad():
            query_emb = model.encode_image(tensor).cpu().numpy().flatten()
        query_emb /= np.linalg.norm(query_emb)
    except Exception as e:
        return [], "", "", f"Image preprocessing failed: {e}"

    # `<=>` is pgvector's cosine distance (1 - cosine similarity), so the
    # `1 - distance` conversion below yields a true similarity. The previous
    # `<->` operator is Euclidean distance, for which that conversion is
    # wrong and can even go negative.
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT path, name, embedding <=> %s::vector AS distance
                FROM faces
                ORDER BY distance
                LIMIT 5
                """,
                (query_emb.tolist(),),
            )
            results = cur.fetchall()
    except Exception as e:
        # Keep the same "message in the last slot" convention as above.
        return [], "", "", f"Database query failed: {e}"

    gallery, captions, names, scores = [], [], [], []
    for path, name, distance in results:
        try:
            # Convert cosine distance to similarity score (1 - distance)
            similarity = 1 - distance
            scores.append(similarity)
            img = Image.open(path)
            gallery.append(img)
            captions.append(f"{name} (Score: {similarity:.2f})")
            names.append(name)
        except Exception as e:
            captions.append(f"β οΈ Error loading match image: {e}")

    # Cosine similarity lies in [-1, 1]; clamp so the score stays in 0..100.
    risk_score = max(0, min(100, int(np.mean(scores) * 100))) if scores else 0

    if not names:
        # Nothing to describe: skip the paid API call instead of sending an
        # empty list of names.
        return (
            gallery,
            "\n".join(captions),
            f"{risk_score}/100",
            "No matching faces were found to analyze.",
        )

    # GPT-4 explanation of the match list
    try:
        prompt = (
            f"The uploaded face matches closely with: {', '.join(names)}. "
            "Based on this, should the user be suspicious? Analyze like a funny but smart AI dating detective."
        )
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You're a playful but intelligent AI face-matching analyst."},
                {"role": "user", "content": prompt},
            ],
        )
        explanation = response.choices[0].message.content
    except Exception as e:
        explanation = f"(OpenAI error): {e}"

    return gallery, "\n".join(captions), f"{risk_score}/100", explanation
# ─────────────────────────────────────────────
# STEP 7: ADD NEW FACE FUNCTION
# ─────────────────────────────────────────────
def add_new_face(image, name, conn):
    """Save an uploaded face, embed it and store it in the ``faces`` table.

    Args:
        image: PIL image to add, or ``None`` when nothing was uploaded.
        name: Person's name as typed by the user.
        conn: Open psycopg2 connection to the database holding ``faces``.

    Returns:
        A human-readable status string. Every failure is reported through
        this string; nothing is raised.
    """
    # Local import so this function carries its own dependency.
    import re

    # Treat a whitespace-only name the same as a missing one.
    clean_name = name.strip() if name else ""
    if image is None or not clean_name:
        return "Please provide both an image and a name."

    try:
        timestamp = int(time.time())
        os.makedirs("uploaded_faces", exist_ok=True)

        # The name comes straight from the UI. Keep only word characters and
        # hyphens so input such as "../x" or "a/b" cannot change the target
        # directory.
        file_stem = re.sub(r"[^\w-]+", "_", clean_name)
        path = f"uploaded_faces/{file_stem}_{timestamp}.jpg"

        # JPEG cannot store alpha or palette modes; normalise before saving.
        image.convert("RGB").save(path)

        # Generate embedding
        emb = embed_image(path)
        if emb is None:
            return "Failed to generate embedding for the image."

        # Add to database
        with conn.cursor() as cur:
            cur.execute(
                "INSERT INTO faces (path, name, embedding) VALUES (%s, %s, %s::vector)",
                (path, clean_name, emb.tolist()),
            )
        conn.commit()
        return f"β Added {clean_name} to the database successfully!"
    except Exception as e:
        return f"β Failed to add face: {e}"
# ─────────────────────────────────────────────
# STEP 8: GRADIO UI
# ─────────────────────────────────────────────
def create_ui():
    """Build and return the Gradio app with a scan tab and an add-face tab.

    Opens the database through ``setup_database`` (raising ``RuntimeError``
    when that returns ``None``), loads the initial images with
    ``populate_database``, then wires both buttons to handlers that share
    the one connection.
    """
    conn = setup_database()
    if conn is None:
        raise RuntimeError("β Database connection failed. Please check your PostgreSQL installation and pgvector extension.")

    populate_database(conn)

    # Gradio passes only the component values, so the connection is captured
    # by these closures.
    def scan_face_wrapper(image):
        return scan_face(image, conn)

    def add_face_wrapper(image, name):
        return add_new_face(image, name, conn)

    with gr.Blocks(title="Tinder Scanner β Real Face Match Detector") as demo:
        gr.Markdown("# Tinder Scanner β Real Face Match Detector")
        gr.Markdown("Scan a face image to find visual matches using CLIP and PostgreSQL, and get a cheeky GPT-4 analysis.")

        # Tab 1: upload a face, show the closest matches and the commentary.
        with gr.Tab("Scan Face"):
            with gr.Row():
                with gr.Column():
                    face_input = gr.Image(type="pil", label="Upload a Face Image")
                    scan_btn = gr.Button("π Scan Face")
                with gr.Column():
                    match_gallery = gr.Gallery(label="π Top Matches", columns=[5], height="auto")
                    match_captions = gr.Textbox(label="Match Names + Similarity Scores")
                    risk_box = gr.Textbox(label="π¨ Cheating Risk Score")
                    explanation_box = gr.Textbox(label="π§ GPT-4 Explanation", lines=5)
            scan_btn.click(
                fn=scan_face_wrapper,
                inputs=[face_input],
                outputs=[match_gallery, match_captions, risk_box, explanation_box],
            )

        # Tab 2: register a new face under a given name.
        with gr.Tab("Add New Face"):
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type="pil", label="Upload New Face Image")
                    upload_name = gr.Textbox(label="Person's Name")
                    add_btn = gr.Button("β Add to Database")
                with gr.Column():
                    status_box = gr.Textbox(label="Result")
            add_btn.click(
                fn=add_face_wrapper,
                inputs=[upload_image, upload_name],
                outputs=status_box,
            )

    return demo
# ─────────────────────────────────────────────
# MAIN EXECUTION
# ─────────────────────────────────────────────
if __name__ == "__main__":
    # Build the interface (which also prepares the database), then serve it.
    demo = create_ui()
    demo.launch()