dp92 commited on
Commit
2dbc820
·
1 Parent(s): 053f91b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -50
app.py CHANGED
@@ -1,50 +1,33 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchvision.models as models
4
- import torchvision.transforms as transforms
5
- import os
6
- from PIL import Image
7
-
8
- # Define the ResNet-50 model
9
- model = models.resnet50(pretrained=True)
10
-
11
- # Remove the classification head (the fully connected layer)
12
- num_features = model.fc.in_features
13
- model.fc = nn.Identity()
14
-
15
- # Set the model to evaluation mode
16
- model.eval()
17
-
18
- # Define the preprocessing transforms
19
- preprocess = transforms.Compose([
20
- transforms.Resize(256),
21
- transforms.CenterCrop(224),
22
- transforms.ToTensor(),
23
- transforms.Normalize(
24
- mean=[0.485, 0.456, 0.406],
25
- std=[0.229, 0.224, 0.225]
26
- )
27
- ])
28
-
29
- # Define the dictionary to store the feature vectors
30
- features = {}
31
-
32
- # Iterate over the images and extract the features
33
- image_dir = 'lfw'
34
- for root, dirs, files in os.walk(image_dir):
35
- for file in files:
36
- # Load the image
37
- image_path = os.path.join(root, file)
38
- image = Image.open(image_path).convert('RGB')
39
-
40
- # Apply the preprocessing transforms
41
- input_tensor = preprocess(image)
42
- input_batch = input_tensor.unsqueeze(0)
43
-
44
- # Extract the features from the penultimate layer
45
- with torch.no_grad():
46
- features_tensor = model(input_batch)
47
- features_vector = torch.squeeze(features_tensor).numpy()
48
-
49
- # Store the feature vector in the dictionary
50
- features[file] = features_vector
 
1
+
2
+
3
+
4
+ import streamlit as st
5
+ import numpy as np
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import cv2
9
+
10
+ # Load the test dataset
11
+ test_data = pd.read_csv("test_dataset.csv")
12
+
13
+ # Create a dropdown to select an image from the test dataset
14
+ selected_image = st.sidebar.selectbox("Select an image", test_data["image"])
15
+
16
+ # Create a file uploader to upload an image
17
+ uploaded_file = st.sidebar.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
18
+
19
+ # Load the selected or uploaded image
20
+ if uploaded_file is not None:
21
+ query_image = cv2.imread(uploaded_file.name)
22
+ else:
23
+ query_image = cv2.imread(selected_image)
24
+
25
+ # Display the query image
26
+ st.image(query_image, caption="Query Image", use_column_width=True)
27
+
28
+ # Use the similarity search system to find the most similar images
29
+ similar_images = find_similar_images(query_image)
30
+
31
+ # Display the most similar images
32
+ for image in similar_images:
33
+ st.image(image, caption="Similar Image", use_column_width=True)