DeepNets commited on
Commit
491acfc
·
verified ·
1 Parent(s): 9cf431c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -144
app.py CHANGED
@@ -1,144 +1,147 @@
1
- import os
2
- import numpy as np
3
- import pandas as pd
4
- import gradio as gr
5
- from glob import glob
6
- import tensorflow as tf
7
- from annoy import AnnoyIndex
8
- from tensorflow import keras
9
-
10
-
11
- def load_image(image_path):
12
- image = tf.io.read_file(image_path)
13
- image = tf.image.decode_jpeg(image, channels=3)
14
- image = tf.image.resize(image, (224, 224))
15
- image = tf.image.convert_image_dtype(image, tf.float32)
16
- image = image/255.
17
- return image.numpy()
18
-
19
-
20
- # Specify Database Path
21
- database_path = './ShoesSubset'
22
-
23
- # Create Example Images
24
- class_names = []
25
- with open('./Shoes-ClassNames.txt', mode='r') as names:
26
- class_names = names.read().split(',')[:-1]
27
-
28
- example_image_paths = [
29
- glob(os.path.join(database_path, name, '*'))[0]
30
- for name in class_names
31
- ]
32
- example_images = [load_image(path) for path in example_image_paths]
33
-
34
- # Load Feature Extractor
35
- feature_extractor_path = './Shoes-FeatureExtractor.keras'
36
- feature_extractor = keras.models.load_model(
37
- feature_extractor_path, compile=False)
38
-
39
- # Load Annoy index
40
- index_path = './ShoesSubset.ann'
41
- annoy_index = AnnoyIndex(256, 'angular')
42
- annoy_index.load(index_path)
43
-
44
-
45
- def similarity_search(
46
- query_image, num_images=5, *_,
47
- feature_extractor=feature_extractor,
48
- annoy_index=annoy_index,
49
- database_path=database_path,
50
- metadata_path='./Shoes.csv'
51
- ):
52
-
53
- if np.max(query_image) == 255:
54
- query_image = query_image/255.
55
-
56
- query_vector = feature_extractor.predict(
57
- query_image[np.newaxis, ...], verbose=0)[0]
58
-
59
- # Compute nearest neighbors
60
- nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
61
-
62
- # Load metadata
63
- metadata = pd.read_csv(metadata_path, index_col=0)
64
- metadata = metadata.iloc[nearest_neighbors]
65
- closest_class = metadata.class_name.values[0]
66
-
67
- # Similar Images
68
- similar_images = [
69
- load_image(os.path.join(database_path, class_name, file_name))
70
- for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
71
- ]
72
-
73
- image_gallery = gr.Gallery(
74
- value=similar_images,
75
- label='Similar Images',
76
- object_fit='fill',
77
- preview=True,
78
- visible=True,
79
- height='50vh'
80
- )
81
- return closest_class, image_gallery
82
-
83
-
84
- # Gradio Application
85
- with gr.Blocks(theme='soft') as app:
86
-
87
- gr.Markdown("# Shoes - Content Based Image Retrieval (CBIR)")
88
- gr.Markdown(
89
- f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
90
- gr.Markdown(
91
- "Disclaimer:- Model might suggest incorrect images, try using a different image.")
92
-
93
- with gr.Row(equal_height=True):
94
- # Image Input
95
- query_image = gr.Image(
96
- label='Query Image',
97
- sources=['upload', 'clipboard'],
98
- height='50vh'
99
- )
100
-
101
- # Output Gallery Display
102
- output_gallery = gr.Gallery(visible=False)
103
-
104
- with gr.Row(equal_height=True):
105
-
106
- # Predicted Class
107
- pred_class = gr.Textbox(
108
- label='Predicted Class', placeholder='Let the model think!!...')
109
-
110
- # Number of images to search
111
- n_images = gr.Slider(
112
- value=10,
113
- label='Number of images to search',
114
- minimum=1,
115
- maximum=99,
116
- step=1
117
- )
118
-
119
- # Search Button
120
- search_btn = gr.Button('Search')
121
-
122
- # Example Images
123
- examples = gr.Examples(
124
- examples=example_images,
125
- inputs=query_image,
126
- label='Something similar to me??',
127
- )
128
-
129
- # Input - On Change
130
- query_image.change(
131
- fn=similarity_search,
132
- inputs=[query_image, n_images],
133
- outputs=[pred_class, output_gallery]
134
- )
135
-
136
- # Search - On Click
137
- search_btn.click(
138
- fn=similarity_search,
139
- inputs=[query_image, n_images],
140
- outputs=[pred_class, output_gallery]
141
- )
142
-
143
- if __name__ == '__main__':
144
- app.launch()
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import gradio as gr
5
+ from glob import glob
6
+ import tensorflow as tf
7
+ from annoy import AnnoyIndex
8
+ from tensorflow import keras
9
+
10
+
11
+ def load_image(image_path):
12
+ image = tf.io.read_file(image_path)
13
+ image = tf.image.decode_jpeg(image, channels=3)
14
+ image = tf.image.resize(image, (224, 224))
15
+ image = tf.image.convert_image_dtype(image, tf.float32)
16
+ image = image/255.
17
+ return image.numpy()
18
+
19
+
20
+ # Specify Database Path
21
+ database_path = './ShoesSubset'
22
+
23
+ # Create Example Images
24
+ class_names = []
25
+ with open('./Shoes-ClassNames.txt', mode='r') as names:
26
+ class_names = names.read().split(',')[:-1]
27
+
28
+ example_image_paths = [
29
+ glob(os.path.join(database_path, name, '*'))[0]
30
+ for name in class_names
31
+ ]
32
+ example_images = [load_image(path) for path in example_image_paths]
33
+
34
+ # Load Feature Extractor
35
+ feature_extractor_path = './Shoes-FeatureExtractor.keras'
36
+ feature_extractor = keras.models.load_model(
37
+ feature_extractor_path, compile=False)
38
+
39
+ # Load Annoy index
40
+ index_path = './ShoesSubset.ann'
41
+ annoy_index = AnnoyIndex(256, 'angular')
42
+ annoy_index.load(index_path)
43
+
44
+
45
+ def similarity_search(
46
+ query_image, num_images=5, *_,
47
+ feature_extractor=feature_extractor,
48
+ annoy_index=annoy_index,
49
+ database_path=database_path,
50
+ metadata_path='./Shoes.csv'
51
+ ):
52
+
53
+ if np.max(query_image) == 255:
54
+ query_image = query_image/255.
55
+
56
+ query_vector = feature_extractor.predict(
57
+ query_image[np.newaxis, ...], verbose=0)[0]
58
+
59
+ # Compute nearest neighbors
60
+ nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
61
+
62
+ # Load metadata
63
+ metadata = pd.read_csv(metadata_path, index_col=0)
64
+ metadata = metadata.iloc[nearest_neighbors]
65
+ closest_class = metadata.class_name.values[0]
66
+
67
+ similar_images_paths = [
68
+ os.path.join(database_path, class_name, file_name)
69
+ for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
70
+ ]
71
+ similar_images = [load_image(img) for img in similar_images_paths]
72
+
73
+ image_gallery = gr.Gallery(
74
+ value=similar_images,
75
+ label='Similar Images',
76
+ object_fit='fill',
77
+ preview=True,
78
+ visible=True,
79
+ height='50vh'
80
+ )
81
+ return closest_class, image_gallery, similar_images_paths
82
+
83
+
84
+ # Gradio Application
85
+ with gr.Blocks(theme='soft') as app:
86
+
87
+ gr.Markdown("# Shoes - Content Based Image Retrieval (CBIR)")
88
+ gr.Markdown(
89
+ f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
90
+ gr.Markdown(
91
+ "Disclaimer:- Model might suggest incorrect images, try using a different image.")
92
+
93
+ with gr.Row(equal_height=True):
94
+ # Image Input
95
+ query_image = gr.Image(
96
+ label='Query Image',
97
+ sources=['upload', 'clipboard'],
98
+ height='50vh'
99
+ )
100
+
101
+ # Output Gallery Display
102
+ output_gallery = gr.Gallery(visible=False)
103
+
104
+ # Hidden output for similar images paths
105
+ similar_paths_output = gr.Textbox(visible=False)
106
+
107
+ with gr.Row(equal_height=True):
108
+
109
+ # Predicted Class
110
+ pred_class = gr.Textbox(
111
+ label='Predicted Class', placeholder='Let the model think!!...')
112
+
113
+ # Number of images to search
114
+ n_images = gr.Slider(
115
+ value=10,
116
+ label='Number of images to search',
117
+ minimum=1,
118
+ maximum=99,
119
+ step=1
120
+ )
121
+
122
+ # Search Button
123
+ search_btn = gr.Button('Search')
124
+
125
+ # Example Images
126
+ examples = gr.Examples(
127
+ examples=example_images,
128
+ inputs=query_image,
129
+ label='Something similar to me??',
130
+ )
131
+
132
+ # Input - On Change
133
+ query_image.change(
134
+ fn=similarity_search,
135
+ inputs=[query_image, n_images],
136
+ outputs=[pred_class, output_gallery, similar_paths_output]
137
+ )
138
+
139
+ # Search - On Click
140
+ search_btn.click(
141
+ fn=similarity_search,
142
+ inputs=[query_image, n_images],
143
+ outputs=[pred_class, output_gallery, similar_paths_output]
144
+ )
145
+
146
+ if __name__ == '__main__':
147
+ app.launch()