import os
import zipfile
import torch
import clip
import numpy as np
from PIL import Image
import gradio as gr
import openai
from tqdm import tqdm
from glob import glob
import chromadb
from chromadb.utils import embedding_functions
import json
import time
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()
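
# Example .env contents (illustrative; substitute your own key):
#   OPENAI_API_KEY=sk-...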
# ─────────────────────────────────────────────
# 📦 STEP 1: UNZIP TO CORRECT STRUCTURE
# ─────────────────────────────────────────────
zip_name = "lfw-faces.zip"
unzip_dir = "lfw-faces"
if not os.path.exists(unzip_dir):
    print("📦 Unzipping...")
    with zipfile.ZipFile(zip_name, "r") as zip_ref:
        zip_ref.extractall(unzip_dir)
    print("✅ Unzipped into:", unzip_dir)

# True image root after unzip
img_root = os.path.join(unzip_dir, "lfw-deepfunneled")
# ─────────────────────────────────────────────
# 🧠 STEP 2: LOAD CLIP MODEL
# ─────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
print(f"✅ CLIP model loaded on {device}")

# ─────────────────────────────────────────────
# 🗃️ STEP 3: CHROMA DB SETUP & EMBEDDING FUNCTION
# ─────────────────────────────────────────────
class ClipEmbeddingFunction:
    """Custom embedding function for Chroma DB using CLIP."""

    def __init__(self, model, preprocess, device):
        self.model = model
        self.preprocess = preprocess
        self.device = device

    def __call__(self, input):
        """Generate embeddings for a list of image paths or PIL images.

        Note: recent ChromaDB versions validate that the parameter is
        named `input`, so keep that name in custom embedding functions.
        """
        embeddings = []
        for item in input:
            try:
                # Paths (new additions from disk) are opened from disk;
                # query images that are already PIL images are used as-is.
                if isinstance(item, str) and os.path.exists(item):
                    img = Image.open(item).convert("RGB")
                else:
                    img = item.convert("RGB") if hasattr(item, "convert") else item
                img_input = self.preprocess(img).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    emb = self.model.encode_image(img_input).cpu().numpy().flatten()
                emb /= np.linalg.norm(emb)
                embeddings.append(emb.tolist())
            except Exception as e:
                print(f"⚠️ Error embedding image: {e}")
                # Return a zero vector as fallback (512 dims for ViT-B/32)
                embeddings.append([0.0] * 512)
        return embeddings
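
# Illustrative sketch (hypothetical helper, not called anywhere): embeds one
# image and checks the vector matches what the cosine index below assumes,
# namely 512 dimensions (ViT-B/32) and unit length.
def _smoke_test_embedding(sample_path):
    """Print the dimension and norm of a single embedded image (demo only)."""
    fn = ClipEmbeddingFunction(model, preprocess, device)
    vec = np.array(fn([sample_path])[0])
    print(f"dim={vec.shape[0]}, norm={np.linalg.norm(vec):.4f}")  # expect dim=512, norm~1.0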
def setup_database():
    """Setup ChromaDB with CLIP embedding function"""
    try:
        # Create persistent client
        client = chromadb.PersistentClient(path="./chroma_db")
        # Create custom embedding function
        embedding_function = ClipEmbeddingFunction(model, preprocess, device)
        # Create or get existing collection
        collection = client.get_or_create_collection(
            name="faces",
            embedding_function=embedding_function,
            metadata={"hnsw:space": "cosine"}  # Use cosine similarity
        )
        print("✅ ChromaDB setup complete.")
        return client, collection
    except Exception as e:
        print(f"❌ Database setup failed: {e}")
        return None, None
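
# Optional convenience (hypothetical helper): drop the persisted "faces"
# collection so a fresh setup_database() + populate_database() run re-embeds
# everything, e.g. after switching CLIP models.
def reset_database(client):
    """Delete the 'faces' collection from the persistent client (demo only)."""
    client.delete_collection("faces")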
def populate_database(collection, limit=500):
    """Populate ChromaDB with images and their embeddings"""
    # Collect all .jpg files inside subfolders
    all_images = sorted(glob(os.path.join(img_root, "*", "*.jpg")))
    selected_images = all_images[:limit]
    if len(selected_images) == 0:
        raise RuntimeError("❌ No image files found in unzipped structure!")
    # Get existing IDs
    existing_ids = set()
    try:
        existing_count = collection.count()
        if existing_count > 0:
            results = collection.get(limit=existing_count)
            existing_ids = set(results["ids"])
    except Exception as e:
        print(f"Error getting existing IDs: {e}")
    # Filter out images that are already in the database
    new_images = []
    new_ids = []
    new_metadatas = []
    for fpath in selected_images:
        # Create a unique ID from the path (os.sep also handles Windows paths)
        image_id = fpath.replace(os.sep, "_")
        if image_id not in existing_ids:
            new_images.append(fpath)
            new_ids.append(image_id)
            name = os.path.splitext(os.path.basename(fpath))[0].replace("_", " ")
            new_metadatas.append({
                "path": fpath,
                "name": name
            })
    if not new_images:
        print("✅ All images are already in the database.")
        return
print(f"π§ Adding {len(new_images)} new images to the database...") | |
# Process images in batches to avoid memory issues | |
batch_size = 50 | |
for i in range(0, len(new_images), batch_size): | |
batch_imgs = new_images[i:i+batch_size] | |
batch_ids = new_ids[i:i+batch_size] | |
batch_metadatas = new_metadatas[i:i+batch_size] | |
print(f"Processing batch {i//batch_size + 1}/{(len(new_images)-1)//batch_size + 1}...") | |
try: | |
collection.add( | |
documents=batch_imgs, # ChromaDB will call our embedding function on these | |
ids=batch_ids, | |
metadatas=batch_metadatas | |
) | |
except Exception as e: | |
print(f"β οΈ Error adding batch to database: {e}") | |
# Count total faces in database | |
total_faces = collection.count() | |
print(f"β Database now contains {total_faces} faces.") | |
# ─────────────────────────────────────────────
# 🔑 STEP 4: LOAD OPENAI API KEY
# ─────────────────────────────────────────────
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    print("⚠️ OpenAI API key not found. GPT-4 analysis will not work.")

# ─────────────────────────────────────────────
# 🔍 STEP 5: FACE MATCHING FUNCTION
# ─────────────────────────────────────────────
def scan_face(user_image, collection):
    """Scan a face image and find matches in the database"""
    if user_image is None:
        return [], "", "", "Please upload a face image."
    try:
        # Embed the query image with the same CLIP pipeline used at index
        # time, then query by embedding. (Passing a PIL image via
        # query_images requires a multimodal collection, so querying by
        # precomputed embedding is the portable route here.)
        embedder = ClipEmbeddingFunction(model, preprocess, device)
        query_embeddings = embedder([user_image])
        results = collection.query(
            query_embeddings=query_embeddings,
            n_results=5,
            include=["metadatas", "distances"]
        )
        metadatas = results.get("metadatas", [[]])[0]
        distances = results.get("distances", [[]])[0]
        gallery, captions, names = [], [], []
        scores = []
        for i, metadata in enumerate(metadatas):
            try:
                path = metadata["path"]
                name = metadata["name"]
                # Convert cosine distance to a similarity score:
                # ChromaDB returns 0 for identical vectors and 2 for
                # opposite ones, so 1 - d/2 maps onto a 0-1 scale.
                distance = distances[i]
                similarity = 1 - (distance / 2)
                scores.append(similarity)
                img = Image.open(path)
                gallery.append(img)
                captions.append(f"{name} (Score: {similarity:.2f})")
                names.append(name)
            except Exception as e:
                captions.append(f"⚠️ Error loading match image: {e}")
        risk_score = min(100, int(np.mean(scores) * 100)) if scores else 0
        # 🧠 GPT-4 EXPLANATION
        explanation = ""
        if openai.api_key and names:
            try:
                prompt = (
                    f"The uploaded face matches closely with: {', '.join(names)}. "
                    f"Based on this, should the user be suspicious? Analyze like a funny but smart AI dating detective."
                )
                response = openai.chat.completions.create(
                    model="gpt-4",
                    messages=[
                        {"role": "system", "content": "You're a playful but intelligent AI face-matching analyst."},
                        {"role": "user", "content": prompt}
                    ]
                )
                explanation = response.choices[0].message.content
            except Exception as e:
                explanation = f"(OpenAI error): {e}"
        else:
            explanation = "OpenAI API key not set or no matches found."
        return gallery, "\n".join(captions), f"{risk_score}/100", explanation
    except Exception as e:
        return [], "", "", f"Error scanning face: {e}"
# ─────────────────────────────────────────────
# 📱 STEP 6: ADD NEW FACE FUNCTION
# ─────────────────────────────────────────────
def add_new_face(image, name, collection):
    """Add a new face to the database"""
    if image is None or not name:
        return "Please provide both an image and a name."
    try:
        # Save image to a temporary file
        timestamp = int(time.time())
        os.makedirs("uploaded_faces", exist_ok=True)
        path = f"uploaded_faces/{name.replace(' ', '_')}_{timestamp}.jpg"
        image.save(path)
        # Add to ChromaDB
        image_id = path.replace('/', '_')
        collection.add(
            documents=[path],
            ids=[image_id],
            metadatas=[{
                "path": path,
                "name": name
            }]
        )
        return f"✅ Added {name} to the database successfully!"
    except Exception as e:
        return f"❌ Failed to add face: {e}"
# ─────────────────────────────────────────────
# 🎛️ STEP 7: GRADIO UI
# ─────────────────────────────────────────────
def create_ui():
    """Create Gradio UI with both scan and add functionality"""
    # Setup database
    client, collection = setup_database()
    if collection is None:
        raise RuntimeError("❌ Database setup failed.")
    # Populate database with initial images
    populate_database(collection)

    # Wrapper functions for Gradio that use the database collection
    def scan_face_wrapper(image):
        return scan_face(image, collection)

    def add_face_wrapper(image, name):
        return add_new_face(image, name, collection)

    with gr.Blocks(title="Tinder Scanner – Real Face Match Detector") as demo:
        gr.Markdown("# Tinder Scanner – Real Face Match Detector")
        gr.Markdown("Scan a face image to find visual matches using CLIP and ChromaDB, and get a cheeky GPT-4 analysis.")
        with gr.Tab("Scan Face"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Upload a Face Image")
                    scan_button = gr.Button("🔍 Scan Face")
                with gr.Column():
                    gallery = gr.Gallery(label="🔍 Top Matches", columns=[5], height="auto")
                    captions = gr.Textbox(label="Match Names + Similarity Scores")
                    risk_score = gr.Textbox(label="🚨 Cheating Risk Score")
                    explanation = gr.Textbox(label="🧠 GPT-4 Explanation", lines=5)
            scan_button.click(
                fn=scan_face_wrapper,
                inputs=[input_image],
                outputs=[gallery, captions, risk_score, explanation]
            )
        with gr.Tab("Add New Face"):
            with gr.Row():
                with gr.Column():
                    new_image = gr.Image(type="pil", label="Upload New Face Image")
                    new_name = gr.Textbox(label="Person's Name")
                    add_button = gr.Button("➕ Add to Database")
                with gr.Column():
                    result = gr.Textbox(label="Result")
            add_button.click(
                fn=add_face_wrapper,
                inputs=[new_image, new_name],
                outputs=result
            )
    return demo
# ─────────────────────────────────────────────
# 🚀 MAIN EXECUTION
# ─────────────────────────────────────────────
if __name__ == "__main__":
    demo = create_ui()
    demo.launch()