mdirshad09 committed on
Commit 2519bba · 1 Parent(s): d74cfc8

Upload 8 files

Files changed (8)
  1. app.py +136 -0
  2. data_extractor.py +50 -0
  3. embeddings_generation.py +12 -0
  4. models.py +21 -0
  5. pinecone.py +46 -0
  6. requirements.txt +77 -0
  7. scrapper.py +140 -0
  8. testing.py +0 -0
app.py ADDED
@@ -0,0 +1,136 @@
+
+ from src.data_extractor import DataExtractor
+ from src.models import FaceNetModel
+ from src.pinecone import Pinecone
+ from src.embeddings_generation import FaceEmbedding
+ import streamlit as st
+ import numpy as np
+ import cv2
+ from PIL import Image
+
+
+ data_extractor = DataExtractor('./data/Json', './data/Images')
+ combined_data, paths = data_extractor.concat_data()
+
+ model = FaceNetModel()
+ mtcnn, resnet = model.initialize_model()
+ transform = model.get_transform()
+
+ embeddings = FaceEmbedding(transform, resnet)
+
+ pinecone = Pinecone('c984cd49-42a6-4aa0-b2f2-e96cfb8f59bc', 'gcp-starter', 'facenet')
+ pinecone_index = pinecone.initialize_index()
+
+
+ def process_images():
+     """Detect faces in every scraped image and upsert their embeddings into Pinecone."""
+     count = 0
+     for index, image_path in enumerate(paths):
+         try:
+             img = Image.open(image_path)
+             img = img.convert("RGB")
+             width, height = img.size
+             boxes, _ = mtcnn.detect(img)
+             if boxes is None:
+                 # No face detected in this image; skip it
+                 continue
+
+             img_id = combined_data['id'][index]
+             img_url = combined_data['Image_URL'][index]
+             page_url = combined_data['Page_URL'][index]
+
+             if len(boxes) == 1:
+                 try:
+                     face_embedding = embeddings.calculate_face_embedding(img, boxes[0])
+                     x1, y1, x2, y2 = [int(coord) for coord in boxes[0]]
+
+                     # Store the box normalized to the image size, flagged as a single-face image
+                     coordinates = [x1 / width, y1 / height, x2 / width, y2 / height]
+                     pinecone.upsert_data(img_id, face_embedding, image_path, img_url, page_url, coordinates, True)
+                 except Exception as e:
+                     print(e)
+                     continue
+             elif len(boxes) > 1:
+                 for box in boxes:
+                     try:
+                         face_embedding = embeddings.calculate_face_embedding(img, box)
+                         x1, y1, x2, y2 = [int(coord) for coord in box]
+                         coordinates = [x1 / width, y1 / height, x2 / width, y2 / height]
+
+                         # Store each face separately, flagged as part of a multi-face image
+                         pinecone.upsert_data(img_id, face_embedding, image_path, img_url, page_url, coordinates, False)
+                     except Exception as e:
+                         print(e)
+                         continue
+
+         except FileNotFoundError:
+             print(f"File not found: {image_path}")
+
+         except OSError:
+             print(f"Not an image file or image file is corrupted: {image_path}")
+
+         except MemoryError:
+             print(f"Out of memory when trying to open image: {image_path}")
+
+         count += 1
+         print(count)
+
+
+ def search_images(query_img):
+     """Embed the first detected face in the query image and search the Pinecone index."""
+     boxes, _ = mtcnn.detect(query_img)
+     if boxes is None:
+         return None
+     query_embedding = embeddings.calculate_face_embedding(query_img, boxes[0])
+     query_embedding = query_embedding.tolist()
+
+     return pinecone.search_data(query_embedding)
+
+
+ def get_image():
+     st.title("Image Upload")
+
+     image_file = st.file_uploader("Upload Image", type=['png', 'jpeg', 'jpg', 'jfif'])
+     if image_file is not None:
+         image = Image.open(image_file)
+         st.image(image, use_column_width=True)
+         matches = search_images(image)
+
+         return matches
+
+
+ def display_image(image):
+     st.image(image, use_column_width=True)
+
+
+ def process_matches(matches):
+     """Display each match; in multi-face images, blur the faces that did not match the query."""
+     for match in matches['matches']:
+         if match['metadata']['Single Face'] == False:
+             img_id = match['metadata']['Image id']
+             results = pinecone_index.query(vector=match['values'], top_k=4, include_values=False, include_metadata=True, filter={'Image id': {'$eq': img_id}})
+             path = match['metadata']['directory path']
+             image = Image.open(path)
+             width, height = image.size
+
+             for face in results['matches']:
+                 if face['score'] < 0.9:
+                     normalized_coordinates = face['metadata']['Face Coordinates']
+                     normalized_coordinates = [float(item) for item in normalized_coordinates]
+
+                     # Convert normalized coordinates back to pixel values
+                     coordinates = [normalized_coordinates[0] * width, normalized_coordinates[1] * height, normalized_coordinates[2] * width, normalized_coordinates[3] * height]
+                     x1, y1, x2, y2 = [int(coord) for coord in coordinates]
+                     face_width = x2 - x1
+                     face_height = y2 - y1
+                     face_region = np.array(image.crop(tuple(coordinates)))
+
+                     blurred_face_region = cv2.GaussianBlur(face_region, (99, 99), 20)
+                     blurred_face_image = Image.fromarray(blurred_face_region)
+
+                     if blurred_face_image.size != (face_width, face_height):
+                         blurred_face_image = blurred_face_image.resize((face_width, face_height))
+
+                     image.paste(blurred_face_image, (x1, y1))
+             display_image(image)
+         else:
+             path = match['metadata']['directory path']
+             img = Image.open(path)
+             display_image(img)
+
+
+ if __name__ == "__main__":
+     # process_images()  # run once to populate the index
+     matches = get_image()
+     if matches is not None:
+         process_matches(matches)
data_extractor.py ADDED
@@ -0,0 +1,50 @@
+ import os
+ import glob
+ import json
+ import pandas as pd
+
+ class DataExtractor:
+     def __init__(self, json_folder_path, image_root_directory):
+         self.json_folder_path = json_folder_path
+         self.image_root_directory = image_root_directory
+
+     def extract_json_data(self):
+         extracted_data = []
+         for filename in os.listdir(self.json_folder_path):
+             if filename.endswith(".json"):
+                 with open(os.path.join(self.json_folder_path, filename), 'r') as json_file:
+                     data = json.load(json_file)
+                     if 'query' in data and 'images' in data:
+                         query = data['query']
+                         images = data['images']
+                         for image_data in images:
+                             extracted_data.append({
+                                 'Class': query,
+                                 'id': image_data['Id'],
+                                 'Image_URL': image_data['url'],
+                                 'Title': image_data['title'],
+                                 'Page_URL': image_data['page_url']
+                             })
+         return pd.DataFrame(extracted_data)
+
+     def extract_image_paths(self):
+         extracted_data = []
+         image_files = glob.glob(os.path.join(self.image_root_directory, '**', '*.jpg'), recursive=True)
+         for image_file in image_files:
+             class_name = os.path.basename(os.path.dirname(image_file))
+             id_name = os.path.splitext(os.path.basename(image_file))[0]
+             extracted_data.append({
+                 'Class': class_name,
+                 'id': id_name,
+                 'Image_Path': image_file
+             })
+         return pd.DataFrame(extracted_data)
+
+     def concat_data(self):
+         json_data = self.extract_json_data()
+         image_data = self.extract_image_paths()
+
+         combined_data = pd.merge(json_data, image_data, on=['id'], how='inner')
+         paths = combined_data['Image_Path']
+         print(paths)
+         return combined_data, paths
embeddings_generation.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+
+ class FaceEmbedding:
+     def __init__(self, transform, resnet):
+         self.transform = transform
+         self.resnet = resnet
+
+     def calculate_face_embedding(self, image, box):
+         face = image.crop(box)
+         face = face.convert("RGB")
+         face = self.transform(face)
+         face = face.unsqueeze(0)
+         with torch.no_grad():
+             # Inference only; no gradients needed
+             face_embedding = self.resnet(face)
+         # Return a flat 512-d embedding so .tolist() yields a plain list of floats for Pinecone
+         return face_embedding.squeeze(0)
models.py ADDED
@@ -0,0 +1,21 @@
+ from facenet_pytorch import InceptionResnetV1, MTCNN
+ import torchvision.transforms as transforms
+
+ class FaceNetModel:
+     def __init__(self):
+         self.mtcnn = None
+         self.resnet = None
+         self.transform = None
+
+     def initialize_model(self):
+         self.mtcnn = MTCNN()
+         self.resnet = InceptionResnetV1(pretrained='vggface2').eval()
+         return self.mtcnn, self.resnet
+
+     def get_transform(self):
+         self.transform = transforms.Compose([
+             transforms.Resize((250, 250)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+         return self.transform
pinecone.py ADDED
@@ -0,0 +1,46 @@
+ import uuid
+ import pinecone
+ from pinecone import PineconeProtocolError
+
+ class Pinecone:
+     def __init__(self, api_key, environment, index_name):
+         self.api_key = api_key
+         self.environment = environment
+         self.index_name = index_name
+         self.index = None
+
+     def initialize_index(self):
+         pinecone.init(api_key=self.api_key, environment=self.environment)
+         self.index = pinecone.Index(self.index_name)
+         return self.index
+
+     def upsert_data(self, img_id, embeddings, path, img_url, page_url, face_coordinates, single_face):
+         vec_id = str(uuid.uuid4())
+         data = []
+         embedding_as_list = embeddings.tolist()
+         metadata = {'Image id': img_id, 'directory path': path, 'Image URL': img_url, 'Page URL': page_url, 'Single Face': single_face}
+         if face_coordinates is not None:
+             # Coordinates are stored as strings because Pinecone metadata lists must hold strings
+             metadata['Face Coordinates'] = [str(coord) for coord in face_coordinates]
+
+         data.append((vec_id, embedding_as_list, metadata))
+         self.index.upsert(data)
+
+     def search_data(self, query_embedding):
+         try:
+             matches = self.index.query(
+                 vector=query_embedding,
+                 top_k=10,
+                 include_values=True,
+                 include_metadata=True
+             )
+         except PineconeProtocolError as e:
+             # Re-initialize the connection and retry the query once
+             print(f"PineconeProtocolError occurred: {e}")
+             pinecone.init(api_key=self.api_key, environment=self.environment)
+             index = pinecone.Index(self.index_name)
+             matches = index.query(
+                 vector=query_embedding,
+                 top_k=10,
+                 include_values=True,
+                 include_metadata=True
+             )
+         return matches
requirements.txt ADDED
@@ -0,0 +1,77 @@
+ altair==5.1.2
+ asgiref==3.7.2
+ attrs==23.1.0
+ blinker==1.7.0
+ cachetools==5.3.2
+ certifi==2023.7.22
+ cffi==1.16.0
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ Django==4.2.7
+ dnspython==2.4.2
+ exceptiongroup==1.1.3
+ facenet-pytorch==2.5.3
+ filelock==3.13.1
+ fsspec==2023.10.0
+ gitdb==4.0.11
+ GitPython==3.1.40
+ h11==0.14.0
+ idna==3.4
+ image==1.5.33
+ importlib-metadata==6.8.0
+ Jinja2==3.1.2
+ jsonschema==4.19.2
+ jsonschema-specifications==2023.7.1
+ loguru==0.7.2
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.1
+ opencv-python==4.8.1.78
+ outcome==1.3.0.post0
+ packaging==23.2
+ pandas==2.1.2
+ Pillow==10.1.0
+ pinecone-client==2.2.4
+ protobuf==4.25.0
+ pyarrow==14.0.0
+ pycparser==2.21
+ pydeck==0.8.1b0
+ Pygments==2.16.1
+ PySocks==1.7.1
+ python-dateutil==2.8.2
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ referencing==0.30.2
+ requests==2.31.0
+ rich==13.6.0
+ rpds-py==0.12.0
+ selenium==4.15.2
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.0
+ sortedcontainers==2.4.0
+ sqlparse==0.4.4
+ streamlit==1.28.1
+ sympy==1.12
+ tenacity==8.2.3
+ toml==0.10.2
+ toolz==0.12.0
+ torch==2.1.0
+ torchvision==0.16.0
+ tornado==6.3.3
+ tqdm==4.66.1
+ trio==0.23.1
+ trio-websocket==0.11.1
+ typing_extensions==4.8.0
+ tzdata==2023.3
+ tzlocal==5.2
+ urllib3==2.0.7
+ validators==0.22.0
+ watchdog==3.0.0
+ win32-setctime==1.1.0
+ wsproto==1.2.0
+ zipp==3.17.0
scrapper.py ADDED
@@ -0,0 +1,140 @@
+ import selenium
+ from selenium import webdriver
+ from selenium.webdriver.common.by import By
+ import time
+ import requests
+ import os
+ import random
+ import hashlib
+ import json
+
+ user_agents = [
+     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/101.0.1234.56 Safari/537.36",
+     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/101.0.1234.56 Safari/537.36",
+     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Firefox/101.0.1234.56",
+     "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/14.1.2",
+ ]
+
+ def fetch_image_data(query: str, max_links_to_fetch: int, wd: webdriver, sleep_between_interactions: int = 5):
+     def scroll_to_end(wd):
+         wd.execute_script("window.scrollTo(0, document.body.scrollHeight);")
+         time.sleep(sleep_between_interactions)
+
+     search_url = "https://www.google.com/search?safe=off&site=&tbm=isch&source=hp&q={q}&oq={q}&gs_l=img"
+
+     wd.get(search_url.format(q=query))
+
+     image_data_list = []
+
+     image_count = 0
+     results_start = 0
+
+     while image_count < max_links_to_fetch:
+         scroll_to_end(wd)
+
+         # Get all image thumbnail results
+         thumbnail_results = wd.find_elements(By.CLASS_NAME, "Q4LuWd")
+         number_results = len(thumbnail_results)
+
+         print(f"Found: {number_results} search results. Extracting links from {results_start}:{number_results}")
+         done = False
+         for img in thumbnail_results[results_start:number_results]:
+             try:
+                 img.click()
+                 time.sleep(sleep_between_interactions)
+             except Exception:
+                 continue
+
+             # Extract image data: URL, title, and dimensions
+             actual_images = wd.find_elements(By.CLASS_NAME, 'pT0Scc')
+             for actual_image in actual_images:
+                 print("ACTUAL IMAGE: ", actual_image)
+                 if actual_image.get_attribute('src') and 'http' in actual_image.get_attribute('src'):
+                     image_url = actual_image.get_attribute('src')
+
+                     response = requests.get(image_url)
+                     if response.status_code == 200:
+                         image_title = actual_image.get_attribute('alt')
+
+                         # Find the parent <a> tag of the image for the page URL
+                         parent_a_tag = actual_image.find_element(By.XPATH, './ancestor::a')
+
+                         # Get the page URL directly from the parent <a> tag
+                         image_page_url = parent_a_tag.get_attribute('href')
+
+                         # Create a folder for the specific query if it doesn't exist
+                         query_folder = os.path.join('images', query)
+                         if not os.path.exists(query_folder):
+                             os.makedirs(query_folder)
+
+                         # Generate a unique file name using the URL hash
+                         file_name = hashlib.sha1(image_url.encode()).hexdigest()[:10]
+
+                         # Create a file path with the .jpg extension
+                         file_path = os.path.join(query_folder, f"{file_name}.jpg")
+                         # id = id.split('/')[-1]
+                         # Save the image
+                         with open(file_path, 'wb') as f:
+                             f.write(response.content)
+
+                         print(f"SUCCESS - saved {image_url} - as {file_path}")
+
+                         # Store the metadata in the list
+                         image_data_list.append({
+                             "url": image_url,
+                             "title": image_title,
+                             "page_url": image_page_url,
+                             "Id": file_name
+                         })
+
+                         image_count += 1  # Increment the image count
+
+                         if image_count >= max_links_to_fetch:
+                             print(f"Found: {len(image_data_list)} images, done!")
+                             done = True
+                             break  # Exit the loop
+             if done:
+                 break
+         if done:
+             break
+
+         # Move the result start point further down
+         results_start = len(thumbnail_results)
+
+     return image_data_list
+
+ if __name__ == '__main__':
+     # Select a random user agent
+     selected_user_agent = random.choice(user_agents)
+
+     # Set the user agent for Edge driver
+     options = webdriver.EdgeOptions()
+     options.add_argument(f'user-agent={selected_user_agent}')
+
+     # Initialize the Edge driver with the specified user agent
+     wd = webdriver.Edge(options=options)
+
+     queries = ["Elon Musk", "Barack Obama", "Taylor Swift", "Bill Gates", "Eminem"]  # change your set of queries here
+
+     for query in queries:
+         num_of_images = 20
+         wd.get('https://google.com')
+         search_box = wd.find_element(By.NAME, 'q')
+         search_box.send_keys(query)
+         image_data_list = fetch_image_data(query, num_of_images, wd)
+
+         # Create a dictionary to store the image data
+         query_image_data = {
+             "query": query,
+             "images": image_data_list
+         }
+
+         # Serialize the image data dictionary to JSON
+         json_data = json.dumps(query_image_data, indent=4)
+
+         # Save the JSON data to a file with the query name
+         json_filename = f"{query}.json"
+         with open(json_filename, 'w') as json_file:
+             json_file.write(json_data)
+
+     wd.quit()
testing.py ADDED
File without changes