File size: 5,441 Bytes
8bc9a5f
4fe2d7f
 
 
 
 
 
 
 
 
8bc9a5f
5a704e3
4fe2d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b2e538
4fe2d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import streamlit as st
import os
import glob
import numpy as np
import cv2
from deepface import DeepFace
from scipy.spatial.distance import cosine
import matplotlib.pyplot as plt
from PIL import Image
import tempfile
import tensorflow as tf

# Browser-tab title and wide page layout for the app.
st.set_page_config(
    page_title="Celebrity Lookalike Finder",
    layout="wide",
)

# Inline stylesheet: padding for the `.main` selector and centered text for
# the `.stTitle` selector. Rendered as raw HTML, hence unsafe_allow_html.
_PAGE_STYLE = """
    <style>
    .main {
        padding: 2rem;
    }
    .stTitle {
        text-align: center;
    }
    </style>
    """
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)

# Page heading and the one-line call to action shown beneath it.
st.title("🌟 Celebrity Lookalike Finder")
st.write("Upload your photo to find your celebrity doppelganger!")

def detect_and_align_face(img_path):
    """Crop the first detected face out of an image file.

    The image is loaded with OpenCV, converted to grayscale and scanned with
    the bundled frontal-face Haar cascade. The first detection is padded by a
    fixed margin on every side (clipped to the image bounds), cropped and
    resized to 224x224. Despite the name, no rotation or landmark-based
    alignment is performed.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The 224x224 face crop on success; the full, unresized image when no
        face is detected or when processing fails after the image was loaded;
        ``None`` when the image cannot be read.
    """
    # Bound before the try so the except handler can always return it, even
    # when the failure happens inside cv2.imread itself.
    img = None
    try:
        img = cv2.imread(img_path)
        if img is None:
            return None

        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)

        if len(faces) == 0:
            return img

        x, y, w, h = faces[0]
        margin = 30
        img_h, img_w = img.shape[:2]

        # Work with explicit corner coordinates so that clamping one edge to
        # the image border does not push the opposite edge further out.
        top = max(0, y - margin)
        bottom = min(img_h, y + h + margin)
        left = max(0, x - margin)
        right = min(img_w, x + w + margin)

        face = img[top:bottom, left:right]
        return cv2.resize(face, (224, 224))

    except Exception as e:
        # Best-effort: report the problem in the UI and fall back to whatever
        # was loaded (None if loading itself failed).
        st.error(f"Error in face detection: {str(e)}")
        return img

def extract_features(img_path):
    """Compute a face embedding for an image with DeepFace (VGG-Face model).

    ``DeepFace.represent`` is called with the OpenCV detector backend and
    ``enforce_detection=False``. Its result is normalised as follows: a list
    is reduced to its first entry; a dict yields its ``'embedding'`` value
    (or, failing that, the first list/array value, flattened); a bare array
    is flattened.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A NumPy array holding the embedding, or ``None`` when the result has
        an unexpected type (a warning is shown) or when any exception is
        raised (an error is shown).
    """
    try:
        result = DeepFace.represent(
            img_path=img_path,
            model_name="VGG-Face",
            enforce_detection=False,
            detector_backend="opencv",
        )

        # Only the first entry of a list result is used.
        embedding = result[0] if isinstance(result, list) else result

        if isinstance(embedding, dict):
            if 'embedding' in embedding:
                return np.array(embedding['embedding'])

            # No 'embedding' key: fall back to the first sequence-like value.
            fallback = next(
                (item for item in embedding.values()
                 if isinstance(item, (list, np.ndarray))),
                None,
            )
            if fallback is not None:
                return np.array(fallback).flatten()
        elif isinstance(embedding, np.ndarray):
            return embedding.flatten()

        st.warning(f"Unexpected embedding type: {type(embedding)}")
        return None

    except Exception as error:
        st.error(f"Error in feature extraction: {str(error)}")
        return None

@st.cache_data
def build_celebrity_database():
    """Embed every file matching ``data/*.*`` and return the results.

    Each matched path is passed to ``extract_features``; files for which it
    returns ``None`` are left out. A progress bar and a status line are
    updated while the files are processed. The function is wrapped in
    ``st.cache_data``.

    Returns:
        A ``(features, paths)`` tuple of two parallel lists: the embeddings
        and the file paths they were computed from.
    """
    image_files = glob.glob('data/*.*')
    total = len(image_files)

    features_out, paths_out = [], []

    bar = st.progress(0)
    status = st.empty()

    for position, path in enumerate(image_files, start=1):
        status.text(f"Processing image {position}/{total}")
        vector = extract_features(path)
        if vector is not None:
            # Keep both lists in lockstep so index i refers to the same file.
            features_out.append(vector)
            paths_out.append(path)
        bar.progress(position / total)

    status.text("Database built successfully!")
    return features_out, paths_out

def find_matches(user_features, celebrity_features, celebrity_paths, top_n=5):
    """Rank celebrities by cosine similarity to the user and display the best.

    Each celebrity embedding whose shape matches ``user_features`` is scored
    with ``1 - cosine distance``. The ``top_n`` highest-scoring images are
    rendered side by side, each captioned with its rank and similarity.

    Args:
        user_features: NumPy embedding of the uploaded photo.
        celebrity_features: Sequence of NumPy embeddings, parallel to
            ``celebrity_paths``.
        celebrity_paths: Image file path for each entry of
            ``celebrity_features``.
        top_n: Number of matches (and display columns) to show.

    Returns:
        None. Results are written to the Streamlit page; a warning is shown
        when no embedding could be compared.
    """
    # Carry the path alongside each score. Scoring into a bare list while
    # skipping incompatible entries would shift the indices and pair scores
    # with the wrong celebrity image.
    scored = []
    for celeb_feature, celeb_path in zip(celebrity_features, celebrity_paths):
        if user_features.shape != celeb_feature.shape:
            continue
        similarity = 1 - cosine(user_features, celeb_feature)
        # A NaN score (e.g. from an all-zero vector) cannot be ranked
        # meaningfully, so it is dropped rather than shown as a match.
        if not np.isfinite(similarity):
            continue
        scored.append((similarity, celeb_path))

    if not scored:
        st.warning("No valid comparisons could be made")
        return

    scored.sort(key=lambda pair: pair[0], reverse=True)
    best = scored[:top_n]

    # Display results in columns
    cols = st.columns(top_n)

    for rank, ((similarity, celeb_path), col) in enumerate(zip(best, cols), start=1):
        with col:
            # Context manager releases the file handle once the image is rendered.
            with Image.open(celeb_path) as celeb_img:
                st.image(celeb_img, caption=f"Match {rank}\nSimilarity: {similarity:.2%}")

def main():
    """Run the page: load the database, accept an upload and show matches.

    The celebrity database is built (or fetched from cache) first. When a
    photo is uploaded it is shown in the left column, written to a temporary
    file, embedded with ``extract_features`` and compared against the
    database with ``find_matches``, whose output goes in the right column.
    The temporary file is always removed afterwards.
    """
    # Load celebrity database
    with st.spinner("Building celebrity database..."):
        celebrity_features, celebrity_paths = build_celebrity_database()

    # File uploader
    uploaded_file = st.file_uploader("Choose a photo", type=['jpg', 'jpeg', 'png'])
    if uploaded_file is None:
        return

    # Create columns for side-by-side display
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Your Photo")
        st.image(uploaded_file)

    # delete=False keeps the file on disk after the `with` closes it, so it
    # can be re-read by path below; it is removed explicitly in `finally`.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        tmp_path = tmp_file.name

    try:
        # Extract features and find matches
        with st.spinner("Finding your celebrity matches..."):
            user_features = extract_features(tmp_path)

            if user_features is not None:
                with col2:
                    st.subheader("Your Celebrity Matches")
                    find_matches(user_features, celebrity_features, celebrity_paths)
            else:
                st.error("Could not process the uploaded image. Please try another photo.")
    finally:
        # Clean up even when processing raises, so temp files do not pile up.
        os.unlink(tmp_path)

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()