dp92 committed
Commit 7f2b08c · 1 Parent(s): 8d7e7f9

Update one.py

Files changed (1)
  1. one.py +47 -69
one.py CHANGED
@@ -1,72 +1,50 @@
  import os
- import numpy as np
  from PIL import Image
- import tensorflow as tf
- from tensorflow.keras.applications.resnet50 import ResNet50

- # Set the path to the dataset
- data_path = '/content/lfw/'
-
- # Load the dataset
- images = []
- labels = []
-
- for folder_name in os.listdir(data_path):
-     folder_path = os.path.join(data_path, folder_name)
-     if not os.path.isdir(folder_path):
-         continue
-     for file_name in os.listdir(folder_path):
-         file_path = os.path.join(folder_path, file_name)
-         if not file_path.endswith('.jpg'):
-             continue
-         image = np.array(Image.open(file_path).convert('RGB'))
-         label = folder_name
-         images.append(image)
-         labels.append(label)
-
- # Convert to numpy arrays
- images = np.array(images)
- labels = np.array(labels)
-
- # Perform necessary preprocessing on the images
- preprocessed_images = tf.keras.applications.resnet50.preprocess_input(images)
-
- # Obtain a ResNet50 model pre-trained on ImageNet
- model = ResNet50(include_top=False, pooling='avg')
-
- # Extract features from the penultimate layer of the network
- features = model.predict(preprocessed_images)
-
- # Store the features in a dictionary
- features_dict = {}
- for i in range(len(labels)):
-     features_dict[labels[i]] = features[i]
-
- # Use a nearest neighbor algorithm to obtain the 10 most similar images to each query image
- from sklearn.neighbors import NearestNeighbors
-
- # Initialize the nearest neighbor algorithm with cosine distance
- nn = NearestNeighbors(n_neighbors=10, metric='cosine')
-
- # Fit the algorithm to the features
- nn.fit(list(features_dict.values()))
-
- # Define a function to retrieve the most similar images to a query image
- def retrieve_similar_images(query_image_path):
-     # Load the query image
-     query_image = np.array(Image.open(query_image_path).convert('RGB'))
-
-     # Perform necessary preprocessing on the query image
-     preprocessed_query_image = tf.keras.applications.resnet50.preprocess_input(np.array([query_image]))
-
-     # Extract features from the query image
-     query_features = model.predict(preprocessed_query_image)
-
-     # Use the nearest neighbor algorithm to retrieve the most similar images
-     distances, indices = nn.kneighbors(query_features)
-
-     # Display the most similar images
-     for i in range(len(indices[0])):
-         image_path = list(features_dict.keys())[list(features_dict.values()).index(features[indices[0][i]])]
-         image = Image.open(os.path.join(data_path, image_path)).convert('RGB')
-         image.show()
 
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ import torchvision.transforms as transforms
  import os
  from PIL import Image

+ # Define the ResNet-50 model
+ model = models.resnet50(pretrained=True)
+
+ # Remove the classification head (the fully connected layer)
+ num_features = model.fc.in_features
+ model.fc = nn.Identity()
+
+ # Set the model to evaluation mode
+ model.eval()
+
+ # Define the preprocessing transforms
+ preprocess = transforms.Compose([
+     transforms.Resize(256),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize(
+         mean=[0.485, 0.456, 0.406],
+         std=[0.229, 0.224, 0.225]
+     )
+ ])
+
+ # Define the dictionary to store the feature vectors
+ features = {}
+
+ # Iterate over the images and extract the features
+ image_dir = 'lfw'
+ for root, dirs, files in os.walk(image_dir):
+     for file in files:
+         # Load the image
+         image_path = os.path.join(root, file)
+         image = Image.open(image_path).convert('RGB')
+
+         # Apply the preprocessing transforms
+         input_tensor = preprocess(image)
+         input_batch = input_tensor.unsqueeze(0)
+
+         # Extract the features from the penultimate layer
+         with torch.no_grad():
+             features_tensor = model(input_batch)
+             features_vector = torch.squeeze(features_tensor).numpy()
+
+         # Store the feature vector in the dictionary
+         features[file] = features_vector
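
The updated script stops once the features dictionary is filled, whereas the removed Keras version went on to index the vectors with scikit-learn's NearestNeighbors and look up the closest matches for a query image. A minimal sketch of the equivalent retrieval step on top of the new PyTorch pipeline could look like the code below; it assumes the model, preprocess, and features objects defined above, that scikit-learn is available, and the retrieve_similar helper name is illustrative rather than part of the commit.

# Sketch: nearest-neighbor lookup over the features dict built above (not part of the commit)
import numpy as np
import torch
from PIL import Image
from sklearn.neighbors import NearestNeighbors

# Keep filenames and vectors in parallel structures so indices map back to files
file_names = list(features.keys())
feature_matrix = np.stack([features[name] for name in file_names])

# Cosine distance is a common choice for comparing CNN embeddings
nn_index = NearestNeighbors(n_neighbors=10, metric='cosine')
nn_index.fit(feature_matrix)

def retrieve_similar(query_image_path):
    # Embed the query image with the same model and transforms as above
    image = Image.open(query_image_path).convert('RGB')
    input_batch = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        query_vector = model(input_batch).squeeze(0).numpy()
    # Return the filenames of the 10 stored images closest to the query
    distances, indices = nn_index.kneighbors(query_vector.reshape(1, -1))
    return [file_names[i] for i in indices[0]]

Two smaller notes on the new code: keying the dictionary by bare filename is safe for LFW, whose filenames embed the person's name and an image index, but storing os.path.join(root, file) would avoid collisions on datasets where filenames repeat across folders; and on torchvision 0.13+ the pretrained=True argument is deprecated in favor of models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1), which loads the same ImageNet weights.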