ans123 committed
Commit dfe5531 · verified · 1 Parent(s): 1cc356f

Update app.py

Files changed (1)
  1. app.py +181 -185
app.py CHANGED
@@ -8,10 +8,14 @@ import gradio as gr
 import openai
 from tqdm import tqdm
 from glob import glob
-import psycopg2
-from psycopg2.extras import execute_values
+import chromadb
+from chromadb.utils import embedding_functions
 import json
 import time
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
 
 # ─────────────────────────────────────────────
 # 📂 STEP 1: UNZIP TO CORRECT STRUCTURE
@@ -29,93 +33,72 @@ if not os.path.exists(unzip_dir):
 img_root = os.path.join(unzip_dir, "lfw-deepfunneled")
 
 # ─────────────────────────────────────────────
-# 🗄️ STEP 2: DATABASE SETUP
-# ─────────────────────────────────────────────
-def setup_database():
-    """Setup PostgreSQL with pgvector extension"""
-    # Database configuration
-    DB_CONFIG = {
-        "dbname": "face_matcher",
-        "user": "postgres",
-        "password": "postgres", # Change this to your actual password
-        "host": "localhost",
-        "port": "5432"
-    }
-
-    try:
-        # Connect to PostgreSQL server to create database if it doesn't exist
-        conn = psycopg2.connect(
-            dbname="postgres",
-            user=DB_CONFIG["user"],
-            password=DB_CONFIG["password"],
-            host=DB_CONFIG["host"]
-        )
-        conn.autocommit = True
-        cur = conn.cursor()
-
-        # Create database if it doesn't exist
-        cur.execute(f"SELECT 1 FROM pg_catalog.pg_database WHERE datname = '{DB_CONFIG['dbname']}'")
-        exists = cur.fetchone()
-        if not exists:
-            cur.execute(f"CREATE DATABASE {DB_CONFIG['dbname']}")
-            print(f"Database {DB_CONFIG['dbname']} created.")
-
-        cur.close()
-        conn.close()
-
-        # Connect to the face_matcher database
-        conn = psycopg2.connect(**DB_CONFIG)
-        conn.autocommit = True
-        cur = conn.cursor()
-
-        # Create pgvector extension if it doesn't exist
-        cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
-
-        # Create faces table if it doesn't exist
-        cur.execute("""
-            CREATE TABLE IF NOT EXISTS faces (
-                id SERIAL PRIMARY KEY,
-                path TEXT UNIQUE NOT NULL,
-                name TEXT NOT NULL,
-                embedding vector(512),
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-            )
-        """)
-
-        # Create index on the embedding column
-        cur.execute("CREATE INDEX IF NOT EXISTS faces_embedding_idx ON faces USING ivfflat (embedding vector_ip_ops)")
-
-        print("✅ Database setup complete.")
-        return conn
-    except Exception as e:
-        print(f"❌ Database setup failed: {e}")
-        return None
-
-# ─────────────────────────────────────────────
-# 🧠 STEP 3: LOAD CLIP MODEL
+# 🧠 STEP 2: LOAD CLIP MODEL
 # ─────────────────────────────────────────────
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model, preprocess = clip.load("ViT-B/32", device=device)
 print(f"✅ CLIP model loaded on {device}")
 
 # ─────────────────────────────────────────────
-# 📊 STEP 4: EMBEDDING FUNCTIONS
+# 🗄️ STEP 3: CHROMA DB SETUP & EMBEDDING FUNCTION
 # ─────────────────────────────────────────────
-def embed_image(image_path):
-    """Generate CLIP embedding for a single image"""
+class ClipEmbeddingFunction:
+    """Custom embedding function for Chroma DB using CLIP"""
+
+    def __init__(self, model, preprocess, device):
+        self.model = model
+        self.preprocess = preprocess
+        self.device = device
+
+    def __call__(self, images):
+        """Generate embeddings for a list of image paths"""
+        embeddings = []
+
+        for image_path in images:
+            try:
+                # Check if the path is a string (for new additions from disk)
+                if isinstance(image_path, str) and os.path.exists(image_path):
+                    img = Image.open(image_path).convert("RGB")
+                else:
+                    # For query images that are already PIL images
+                    img = image_path.convert("RGB") if hasattr(image_path, 'convert') else image_path
+
+                img_input = self.preprocess(img).unsqueeze(0).to(self.device)
+                with torch.no_grad():
+                    emb = self.model.encode_image(img_input).cpu().numpy().flatten()
+                emb /= np.linalg.norm(emb)
+                embeddings.append(emb.tolist())
+            except Exception as e:
+                print(f"⚠️ Error embedding image: {e}")
+                # Return a zero vector as fallback
+                embeddings.append([0] * 512)
+
+        return embeddings
+
+def setup_database():
+    """Setup ChromaDB with CLIP embedding function"""
     try:
-        img = Image.open(image_path).convert("RGB")
-        img_input = preprocess(img).unsqueeze(0).to(device)
-        with torch.no_grad():
-            emb = model.encode_image(img_input).cpu().numpy().flatten()
-        emb /= np.linalg.norm(emb)
-        return emb
+        # Create persistent client
+        client = chromadb.PersistentClient(path="./chroma_db")
+
+        # Create custom embedding function
+        embedding_function = ClipEmbeddingFunction(model, preprocess, device)
+
+        # Create or get existing collection
+        collection = client.get_or_create_collection(
+            name="faces",
+            embedding_function=embedding_function,
+            metadata={"hnsw:space": "cosine"} # Use cosine similarity
+        )
+
+        print("✅ ChromaDB setup complete.")
+        return client, collection
     except Exception as e:
-        print(f"⚠️ Error embedding {image_path}: {e}")
-        return None
+        print(f"❌ Database setup failed: {e}")
+        return None, None
 
-def populate_database(conn, limit=500):
-    """Populate database with images and their embeddings"""
+def populate_database(collection, limit=500):
+    """Populate ChromaDB with images and their embeddings"""
     # Collect all .jpg files inside subfolders
     all_images = sorted(glob(os.path.join(img_root, "*", "*.jpg")))
    selected_images = all_images[:limit]
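
A note on ClipEmbeddingFunction: newer chromadb releases validate custom embedding functions against the chromadb.EmbeddingFunction interface, whose __call__ must accept a single argument named input. If the plain class in this hunk trips that validation, a conforming variant would look roughly like the sketch below (hypothetical name ClipEmbeddingFn; same CLIP logic, assuming the module-level model, preprocess, and device):

    import os
    import numpy as np
    import torch
    from PIL import Image
    from chromadb import Documents, EmbeddingFunction, Embeddings

    class ClipEmbeddingFn(EmbeddingFunction):
        """CLIP embeddings with the __call__(self, input) signature Chroma checks for."""

        def __init__(self, model, preprocess, device):
            self.model = model
            self.preprocess = preprocess
            self.device = device

        def __call__(self, input: Documents) -> Embeddings:
            embeddings = []
            for item in input:
                # Accept either an on-disk path or an already-open PIL image.
                img = Image.open(item) if isinstance(item, str) and os.path.exists(item) else item
                tensor = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    emb = self.model.encode_image(tensor).cpu().numpy().flatten()
                embeddings.append((emb / np.linalg.norm(emb)).tolist())
            return embeddings
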
@@ -123,126 +106,141 @@ def populate_database(conn, limit=500):
     if len(selected_images) == 0:
         raise RuntimeError("❌ No image files found in unzipped structure!")
 
-    cur = conn.cursor()
-
-    # Check which images are already in the database
-    cur.execute("SELECT path FROM faces")
-    existing_paths = set(path[0] for path in cur.fetchall())
+    # Get existing IDs
+    existing_ids = set()
+    try:
+        existing_count = collection.count()
+        if existing_count > 0:
+            results = collection.get(limit=existing_count)
+            existing_ids = set(results['ids'])
+    except Exception as e:
+        print(f"Error getting existing IDs: {e}")
 
     # Filter out images that are already in the database
-    new_images = [path for path in selected_images if path not in existing_paths]
+    new_images = []
+    new_ids = []
+    new_metadatas = []
+
+    for fpath in selected_images:
+        # Create ID from path
+        image_id = fpath.replace('/', '_')
+        if image_id not in existing_ids:
+            new_images.append(fpath)
+            new_ids.append(image_id)
+            name = os.path.splitext(os.path.basename(fpath))[0].replace("_", " ")
+            new_metadatas.append({
+                "path": fpath,
+                "name": name
+            })
 
     if not new_images:
         print("✅ All images are already in the database.")
         return
 
-    print(f"🧠 Generating CLIP embeddings for {len(new_images)} new images...")
+    print(f"🧠 Adding {len(new_images)} new images to the database...")
 
     # Process images in batches to avoid memory issues
     batch_size = 50
     for i in range(0, len(new_images), batch_size):
-        batch = new_images[i:i+batch_size]
-        data_to_insert = []
+        batch_imgs = new_images[i:i+batch_size]
+        batch_ids = new_ids[i:i+batch_size]
+        batch_metadatas = new_metadatas[i:i+batch_size]
 
-        for fpath in tqdm(batch, desc=f"Embedding batch {i//batch_size + 1}"):
-            try:
-                emb = embed_image(fpath)
-                if emb is not None:
-                    name = os.path.splitext(os.path.basename(fpath))[0].replace("_", " ")
-                    data_to_insert.append((fpath, name, emb.tolist()))
-            except Exception as e:
-                print(f"⚠️ Error with {fpath}: {e}")
+        print(f"Processing batch {i//batch_size + 1}/{(len(new_images)-1)//batch_size + 1}...")
 
-        # Insert batch into database
-        if data_to_insert:
-            execute_values(
-                cur,
-                "INSERT INTO faces (path, name, embedding) VALUES %s ON CONFLICT (path) DO NOTHING",
-                [(d[0], d[1], d[2]) for d in data_to_insert],
-                template="(%s, %s, %s::vector)"
+        try:
+            collection.add(
+                documents=batch_imgs, # ChromaDB will call our embedding function on these
+                ids=batch_ids,
+                metadatas=batch_metadatas
             )
-            conn.commit()
+        except Exception as e:
+            print(f"⚠️ Error adding batch to database: {e}")
 
     # Count total faces in database
-    cur.execute("SELECT COUNT(*) FROM faces")
-    total_faces = cur.fetchone()[0]
+    total_faces = collection.count()
     print(f"✅ Database now contains {total_faces} faces.")
 
 # ─────────────────────────────────────────────
-# 🔐 STEP 5: LOAD OPENAI API KEY
+# 🔐 STEP 4: LOAD OPENAI API KEY
 # ─────────────────────────────────────────────
 openai.api_key = os.getenv("OPENAI_API_KEY")
+if not openai.api_key:
+    print("⚠️ OpenAI API key not found. GPT-4 analysis will not work.")
 
 # ─────────────────────────────────────────────
-# 🔍 STEP 6: FACE MATCHING FUNCTION
+# 🔍 STEP 5: FACE MATCHING FUNCTION
 # ─────────────────────────────────────────────
-def scan_face(user_image, conn):
+def scan_face(user_image, collection):
     """Scan a face image and find matches in the database"""
     if user_image is None:
         return [], "", "", "Please upload a face image."
 
     try:
-        user_image = user_image.convert("RGB")
-        tensor = preprocess(user_image).unsqueeze(0).to(device)
-        with torch.no_grad():
-            query_emb = model.encode_image(tensor).cpu().numpy().flatten()
-        query_emb /= np.linalg.norm(query_emb)
-    except Exception as e:
-        return [], "", "", f"Image preprocessing failed: {e}"
+        # Query database for similar faces using the image directly
+        results = collection.query(
+            query_embeddings=None, # Will be generated by our embedding function
+            query_images=[user_image], # Pass the PIL image directly
+            n_results=5,
+            include=["metadatas", "distances"]
+        )
+
+        metadatas = results.get("metadatas", [[]])[0]
+        distances = results.get("distances", [[]])[0]
+
+        gallery, captions, names = [], [], []
+        scores = []
+
+        for i, metadata in enumerate(metadatas):
+            try:
+                path = metadata["path"]
+                name = metadata["name"]
+
+                # Convert distance to similarity score (1 - normalized_distance)
+                # ChromaDB uses cosine distance, so 0 is most similar, 2 is most different
+                distance = distances[i]
+                similarity = 1 - (distance / 2) # Convert to 0-1 scale
+                scores.append(similarity)
+
+                img = Image.open(path)
+                gallery.append(img)
+                captions.append(f"{name} (Score: {similarity:.2f})")
+                names.append(name)
+            except Exception as e:
+                captions.append(f"⚠️ Error loading match image: {e}")
+
+        risk_score = min(100, int(np.mean(scores) * 100)) if scores else 0
 
-    # Query database for similar faces
-    cur = conn.cursor()
-    emb_list = query_emb.tolist()
-    cur.execute("""
-        SELECT path, name, embedding <-> %s::vector AS distance
-        FROM faces
-        ORDER BY distance
-        LIMIT 5
-    """, (emb_list,))
-
-    results = cur.fetchall()
-
-    gallery, captions, names = [], [], []
-    scores = []
-
-    for path, name, distance in results:
-        try:
-            # Convert distance to similarity score (1 - distance)
-            similarity = 1 - distance
-            scores.append(similarity)
-
-            img = Image.open(path)
-            gallery.append(img)
-            captions.append(f"{name} (Score: {similarity:.2f})")
-            names.append(name)
-        except Exception as e:
-            captions.append(f"⚠️ Error loading match image: {e}")
-
-    risk_score = min(100, int(np.mean(scores) * 100)) if scores else 0
+        # 🧠 GPT-4 EXPLANATION
+        explanation = ""
+        if openai.api_key and names:
+            try:
+                prompt = (
+                    f"The uploaded face matches closely with: {', '.join(names)}. "
+                    f"Based on this, should the user be suspicious? Analyze like a funny but smart AI dating detective."
+                )
+                response = openai.chat.completions.create(
+                    model="gpt-4",
+                    messages=[
+                        {"role": "system", "content": "You're a playful but intelligent AI face-matching analyst."},
+                        {"role": "user", "content": prompt}
+                    ]
+                )
+                explanation = response.choices[0].message.content
+            except Exception as e:
+                explanation = f"(OpenAI error): {e}"
+        else:
+            explanation = "OpenAI API key not set or no matches found."
 
-    # 🧠 GPT-4 EXPLANATION
-    try:
-        prompt = (
-            f"The uploaded face matches closely with: {', '.join(names)}. "
-            f"Based on this, should the user be suspicious? Analyze like a funny but smart AI dating detective."
-        )
-        response = openai.chat.completions.create(
-            model="gpt-4",
-            messages=[
-                {"role": "system", "content": "You're a playful but intelligent AI face-matching analyst."},
-                {"role": "user", "content": prompt}
-            ]
-        )
-        explanation = response.choices[0].message.content
+        return gallery, "\n".join(captions), f"{risk_score}/100", explanation
+
     except Exception as e:
-        explanation = f"(OpenAI error): {e}"
-
-    return gallery, "\n".join(captions), f"{risk_score}/100", explanation
+        return [], "", "", f"Error scanning face: {e}"
 
 # ─────────────────────────────────────────────
-# 🌱 STEP 7: ADD NEW FACE FUNCTION
+# 🌱 STEP 6: ADD NEW FACE FUNCTION
 # ─────────────────────────────────────────────
-def add_new_face(image, name, conn):
+def add_new_face(image, name, collection):
     """Add a new face to the database"""
     if image is None or not name:
         return "Please provide both an image and a name."
@@ -254,46 +252,44 @@ def add_new_face(image, name, conn):
         path = f"uploaded_faces/{name.replace(' ', '_')}_{timestamp}.jpg"
         image.save(path)
 
-        # Generate embedding
-        emb = embed_image(path)
-        if emb is None:
-            return "Failed to generate embedding for the image."
-
-        # Add to database
-        cur = conn.cursor()
-        cur.execute(
-            "INSERT INTO faces (path, name, embedding) VALUES (%s, %s, %s::vector)",
-            (path, name, emb.tolist())
+        # Add to ChromaDB
+        image_id = path.replace('/', '_')
+        collection.add(
+            documents=[path],
+            ids=[image_id],
+            metadatas=[{
+                "path": path,
+                "name": name
+            }]
         )
-        conn.commit()
 
         return f"✅ Added {name} to the database successfully!"
     except Exception as e:
         return f"❌ Failed to add face: {e}"
 
 # ─────────────────────────────────────────────
-# 🎛️ STEP 8: GRADIO UI
+# 🎛️ STEP 7: GRADIO UI
 # ─────────────────────────────────────────────
 def create_ui():
     """Create Gradio UI with both scan and add functionality"""
-    # Setup database connection
-    conn = setup_database()
-    if conn is None:
-        raise RuntimeError("❌ Database connection failed. Please check your PostgreSQL installation and pgvector extension.")
+    # Setup database
+    client, collection = setup_database()
+    if collection is None:
+        raise RuntimeError("❌ Database setup failed.")
 
     # Populate database with initial images
-    populate_database(conn)
+    populate_database(collection)
 
-    # Wrapper functions for Gradio that use the database connection
+    # Wrapper functions for Gradio that use the database collection
     def scan_face_wrapper(image):
-        return scan_face(image, conn)
+        return scan_face(image, collection)
 
     def add_face_wrapper(image, name):
-        return add_new_face(image, name, conn)
+        return add_new_face(image, name, collection)
 
     with gr.Blocks(title="Tinder Scanner – Real Face Match Detector") as demo:
         gr.Markdown("# Tinder Scanner – Real Face Match Detector")
-        gr.Markdown("Scan a face image to find visual matches using CLIP and PostgreSQL, and get a cheeky GPT-4 analysis.")
+        gr.Markdown("Scan a face image to find visual matches using CLIP and ChromaDB, and get a cheeky GPT-4 analysis.")
 
         with gr.Tab("Scan Face"):
             with gr.Row():
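
For reviewing the migrated store outside Gradio, a quick inspection snippet (hypothetical, assuming the ./chroma_db path and "faces" collection name used above):

    import chromadb

    client = chromadb.PersistentClient(path="./chroma_db")
    collection = client.get_collection("faces")  # raises if setup never ran
    print(collection.count(), "faces indexed")
    print(collection.peek(3)["metadatas"])  # sample a few stored records
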
 
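Dependency footprint after this commit, judging from the imports alone (a sketch for the Space's requirements.txt; version pins omitted, and the CLIP line assumes OpenAI's reference repo):

    gradio
    openai
    torch
    tqdm
    numpy
    Pillow
    chromadb
    python-dotenv
    git+https://github.com/openai/CLIP.git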