TheProjectsGuy committed on
Commit
61364af
·
1 Parent(s): 9543bb3

Uploaded normal app layout (without extractor)

Browse files
Files changed (4) hide show
  1. app.py +288 -0
  2. packages.txt +2 -0
  3. requirements.txt +11 -0
  4. utilities.py +478 -0
app.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Show VLAD clustering for set of example images or a user image
2
+ """
3
+ User input:
4
+ - Domain: Indoor, Aerial, or Urban
5
+ - Image: Image to be clustered
6
+ - Cluster numbers (to visualize)
7
+ - Pixel coordinates (to pick further clusters)
8
+ - A unique cache ID (to store the DINO forward passes)
9
+
10
+ There are example images for each domain.
11
+
12
+ Output:
13
+ - All images with cluster assignments
14
+
15
+ Some Gradio links:
16
+ - Controlling layout
17
+ - https://www.gradio.app/guides/quickstart#blocks-more-flexibility-and-control
18
+ - Data state (persistence)
19
+ - https://www.gradio.app/guides/interface-state
20
+ - https://www.gradio.app/docs/state
21
+ - Layout control
22
+ - https://www.gradio.app/guides/controlling-layout
23
+ - https://www.gradio.app/guides/blocks-and-event-listeners
24
+ """
25
+
26
+ # %%
27
+ import os
28
+ import gradio as gr
29
+ import numpy as np
30
+ import cv2 as cv
31
+ import torch
32
+ from torch import nn
33
+ from torch.nn import functional as F
34
+ from torchvision import transforms as tvf
35
+ from torchvision.transforms import functional as T
36
+ from PIL import Image
37
+ import matplotlib.pyplot as plt
38
+ import distinctipy as dipy
39
+ from typing import Literal, List
40
+ import gradio as gr
41
+ import time
42
+ import glob
43
+ import shutil
44
+ from copy import deepcopy
45
+ # DINOv2 imports
46
+ from utilities import DinoV2ExtractFeatures
47
+ from utilities import VLAD
48
+
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+
51
+ # %%
52
+ # Configurations
53
# Type aliases for the configuration values below
T1 = Literal["query", "key", "value", "token"]  # Descriptor facet
T2 = Literal["aerial", "indoor", "urban"]       # Domain (lower case)
DOMAINS = ["aerial", "indoor", "urban"]
T3 = Literal["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14",
             "dinov2_vitg14"]                   # DINOv2 model name


def _ex(x):
    """Expand a leading `~` in `x` and return its real path."""
    return os.path.realpath(os.path.expanduser(x))


dino_model: T3 = "dinov2_vitg14"
desc_layer: int = 31
desc_facet: T1 = "value"
num_c: int = 8
cache_dir: str = _ex("./cache")     # Directory containing program cache
max_img_size: int = 1024    # Image resolution (max dim/size)
max_num_imgs: int = 10      # Max number of images to upload
share: bool = False         # Share application using .gradio link

# Verify inputs. An explicit raise is used (instead of `assert`) so
# that the check is still performed when Python runs with `-O`.
if not os.path.isdir(cache_dir):
    raise FileNotFoundError(f"Cache directory not found: {cache_dir}")
70
+
71
+ # %%
72
+ # Model and transforms
73
print("Loading DINO model")
# The feature extractor is left disabled here (no model is loaded).
# To enable it, replace the `None` assignment below with:
#   extractor = DinoV2ExtractFeatures(dino_model, desc_layer,
#                                     desc_facet, device=device)
# While it is `None`, `get_descs` cannot call it.
extractor = None
print("DINO model loaded")
# Sub-path (inside the cache) that identifies the VLAD vocabulary
ext_s = f"{dino_model}/l{desc_layer}_{desc_facet}_c{num_c}"
vc_dir = os.path.join(cache_dir, "vocabulary", ext_s)
# Base image transformations: to tensor, then per-channel normalization
base_tf = tvf.Compose([
    tvf.ToTensor(),
    tvf.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
87
+
88
+
89
+ # %%
90
+ # Get VLAD object
91
def get_vlad_clusters(domain, pr = gr.Progress()):
    """
    Build a VLAD object from the cached cluster centers of a domain.

    Parameters:
    - domain:   Domain name; it is lower-cased and must be in `DOMAINS`
    - pr:       Gradio progress tracker

    Returns:
    - A status message (str)
    - The VLAD object, or None if the cluster centers file
      `<vc_dir>/<domain>/c_centers.pt` does not exist
    """
    domain_key: T2 = str(domain).lower()
    assert domain_key in DOMAINS, "Invalid domain"
    pr(0, desc="Loading VLAD clusters")
    centers_path = os.path.join(vc_dir, domain_key, "c_centers.pt")
    if os.path.isfile(centers_path):
        centers = torch.load(centers_path)
        pr(0.5)
        # Rows are cluster centers, columns are descriptor dimensions
        n_clusters, d_dim = centers.shape[0], centers.shape[1]
        vlad = VLAD(n_clusters, d_dim,
                    cache_dir=os.path.dirname(centers_path))
        vlad.fit(None)  # Restore the cluster centers from the cache
        pr(1)
        return f"VLAD clusters loaded for: {domain}", vlad
    return f"Cluster centers not found for: {domain}", None
108
+
109
+
110
+ # %%
111
+ # Get VLAD descriptors
112
@torch.no_grad()
def get_descs(imgs_batch, pr = gr.Progress()):
    """
    Extract patch descriptors for a batch of images.

    Each image is normalized with `base_tf`, downscaled (keeping the
    aspect ratio) if its larger side exceeds `max_img_size`, and
    center-cropped so that both sides are multiples of 14 before
    being passed to the module-level `extractor` (which must be a
    callable; the call fails while it is None).

    Parameters:
    - imgs_batch:   List of images (numpy arrays). `None` entries
                    (image slots without an upload) are skipped; a
                    `None` batch is treated as empty.
    - pr:           Gradio progress tracker

    Returns:
    - patch_descs:  List of dicts with keys "img" (PIL image, resized
                    if it was too large) and "descs" (extractor output
                    moved to CPU)
    - A status message (str) with the number of processed images
    """
    pr(0, desc="Extracting descriptors")
    patch_descs = []
    imgs_batch = [] if imgs_batch is None else imgs_batch
    for i, img in enumerate(imgs_batch):
        if img is None:     # Empty image slot: nothing to extract
            pr((i+1) / len(imgs_batch))
            continue
        # Convert to PIL image
        pil_img = Image.fromarray(img)
        img_pt = base_tf(pil_img).to(device)
        _, h, w = img_pt.shape
        if max(h, w) > max_img_size:
            print(f"Image {i+1}: {img_pt.shape[-2:]}, outside")
            # Maintain aspect ratio
            if h == max(h, w):
                w = int(w * max_img_size / h)
                h = max_img_size
            else:
                h = int(h * max_img_size / w)
                w = max_img_size
            img_pt = T.resize(img_pt, (h, w),
                    interpolation=T.InterpolationMode.BICUBIC)
            pil_img = pil_img.resize((w, h))    # Backup
        # Make image patchable (sides as multiples of 14)
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        img_pt = tvf.CenterCrop((h_new, w_new))(img_pt)[None, ...]
        # Extract descriptors
        ret = extractor(img_pt).cpu()   # [1, n_p, d]
        patch_descs.append({"img": pil_img, "descs": ret})
        pr((i+1) / len(imgs_batch))
    return patch_descs, \
        f"Descriptors extracted for {len(patch_descs)} images"
144
+
145
+
146
+ # %%
147
+ # Assign VLAD clusters (descriptor assignment)
148
def assign_vlad(patch_descs, vlad, pr = gr.Progress()):
    """
    Assign every patch descriptor to a VLAD cluster.

    For each image, the residuals to all cluster centers are
    obtained from `vlad.generate_res_vec`; each patch is given the
    cluster whose residual has the smallest sum of absolute values.
    The patch-level label grid is then upsampled (nearest) by 14 in
    both directions.

    Parameters:
    - patch_descs:  List of dicts with keys "img" and "descs"
    - vlad:         VLAD object providing `generate_res_vec`
    - pr:           Gradio progress tracker

    Returns:
    - desc_assignments: List of label tensors, one per image
    - A status message (str)
    """
    all_descs = [entry["descs"] for entry in patch_descs]
    pr(0, desc="Assigning VLAD clusters")
    n_imgs = len(all_descs)
    desc_assignments = []   # List[Tensor;shape=('h', 'w');int]
    for idx, descs in enumerate(all_descs):
        # Residual vectors; 'n' could differ (based on img sizes)
        residuals = vlad.generate_res_vec(descs[0])  # ['n', n_c, d]
        height, width, _ = np.array(patch_descs[idx]["img"]).shape
        rows, cols = height // 14, width // 14
        assert rows * cols == residuals.shape[0], "Residual incorrect!"
        # Cluster with the smallest L1 residual, as a patch grid
        l1_dists = residuals.abs().sum(dim=2)
        labels = l1_dists.argmin(dim=1).reshape(rows, cols)
        # Upsample patch labels to pixels (float for interpolate)
        upsampled = F.interpolate(labels[None, None, ...].to(float),
                (rows * 14, cols * 14), mode="nearest")
        desc_assignments.append(upsampled[0, 0].to(labels.dtype))
        pr((idx+1) / n_imgs)
    pr(1.0)
    return desc_assignments, "VLAD clusters assigned"
169
+
170
+
171
+ # %%
172
+ # Cluster assignments to images
173
def get_ca_images(desc_assignments, patch_descs, alpha,
            pr = gr.Progress()):
    """
    Blend the cluster assignments (as colors) over the images.

    Parameters:
    - desc_assignments: List of per-pixel cluster label maps (one
                        per image); None or empty if not ready
    - patch_descs:      List of dicts; the "img" entry (PIL image)
                        is used as the background
    - alpha:            Blend weight of the cluster colors; the
                        background gets `1 - alpha`
    - pr:               Gradio progress tracker

    Returns:
    - res_imgs: List of blended PIL images (None if there are no
                assignments)
    - A status message (str)
    """
    if desc_assignments is None or len(desc_assignments) == 0:
        return None, "First load images"
    c_colors = dipy.get_colors(num_c, rng=928,
            colorblind_type="Deuteranomaly")
    np_colors = (np.array(c_colors) * 255).astype(np.uint8)
    # Get images with clusters
    pil_imgs = [pd["img"] for pd in patch_descs]
    res_imgs = []   # List[PIL.Image]
    pr(0, desc="Generating cluster assignment images")
    for i, pil_img in enumerate(pil_imgs):
        # Descriptor assignment image: [h, w, 3]
        da = desc_assignments[i]    # ['h', 'w']
        # Compare in numpy (labels may be stored as a torch tensor)
        if isinstance(da, torch.Tensor):
            da_np = da.cpu().numpy()
        else:
            da_np = np.asarray(da)
        da_img = np.zeros((*da_np.shape, 3), dtype=np.uint8)
        for k in range(num_c):
            da_img[da_np == k] = np_colors[k]
        # Background image: [h, w, 3]
        img_np = np.array(pil_img, dtype=np.uint8)
        h, w, _ = img_np.shape
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        # Center crop (same offsets as torchvision's CenterCrop, which
        # `get_descs` applies before extraction) so that the colors
        # line up with the pixels the descriptors came from
        top = int(round((h - h_new) / 2.0))
        left = int(round((w - w_new) / 2.0))
        img_np = np.ascontiguousarray(
                img_np[top:top+h_new, left:left+w_new])
        res_img = cv.addWeighted(img_np, 1 - alpha, da_img, alpha, 0.)
        res_imgs.append(Image.fromarray(res_img))
        pr((i+1) / len(pil_imgs))
    pr(1.0)
    return res_imgs, "Cluster assignment images generated"
203
+
204
+
205
+ # %%
206
print("Interface build started")
# Build the interface
with gr.Blocks() as demo:
    # ---- Helper functions ----
    def var_num_img(s):
        """
        Show the first `s` input image slots and hide the rest.
        Returns one update per slot (`max_num_imgs` in total).
        """
        n = int(s)  # Slider value as int
        return [gr.Image.update(label=f"Image {i+1}", visible=True) \
                for i in range(n)] + [gr.Image.update(visible=False) \
                for _ in range(max_num_imgs - n)]

    # ---- State declarations ----
    vlad = gr.State()               # VLAD object
    desc_assignments = gr.State()   # Cluster assignments
    imgs_batch = gr.State()         # Images as batch
    patch_descs = gr.State()        # Patch descriptors

    # ---- All UI elements ----
    d_vals = [k.title() for k in DOMAINS]
    domain = gr.Radio(d_vals, value=d_vals[0])
    nimg_s = gr.Slider(1, max_num_imgs, value=1, step=1,
            label="How many images?")   # How many images?
    with gr.Row():  # Dynamic row (images in columns)
        imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
                for i in range(nimg_s.value)] + \
                [gr.Image(visible=False) \
                for _ in range(max_num_imgs - nimg_s.value)]
        for img in imgs:    # Set image as "input"
            img.change(lambda _: None, img)
    with gr.Row():  # Dynamic row of output (cluster) images
        imgs2 = [gr.Image(label=f"VLAD Clusters {i+1}",
                visible=False) for i in range(max_num_imgs)]
    nimg_s.change(var_num_img, nimg_s, imgs)
    blend_alpha = gr.Slider(0, 1, 0.4, step=0.01,   # Cluster centers
            label="Blend alpha (weight for cluster centers)")
    bttn1 = gr.Button("Click Me!")  # Cluster assignment
    out_msg1 = gr.Markdown("Select domain and upload images")
    out_msg2 = gr.Markdown("For descriptor extraction")
    out_msg3 = gr.Markdown("Followed by VLAD assignment")
    out_msg4 = gr.Markdown("Followed by cluster images")

    # ---- Utility functions ----
    def batch_images(data):
        """
        Collect the first N input images into a list, where N is the
        slider value. `data` maps input components to their values.
        """
        # Cast like `var_num_img` does, so `range` always gets an int
        sv = int(data[nimg_s])
        images: List[np.ndarray] = [data[imgs[k]] for k in range(sv)]
        return images

    def unbatch_images(imgs_batch):
        """
        Turn a list of result images into one update per output slot;
        slots without an image stay hidden.
        """
        ret = [gr.Image.update(visible=False) \
                for _ in range(max_num_imgs)]
        if imgs_batch is None or len(imgs_batch) == 0:
            return ret
        # Never index past the available output slots
        for i, img_pil in enumerate(imgs_batch[:max_num_imgs]):
            img_np = np.array(img_pil)
            ret[i] = gr.Image.update(img_np, visible=True)
        return ret

    # ---- Main pipeline ----
    # Get the VLAD cluster assignment images on click
    bttn1.click(get_vlad_clusters, domain, [out_msg1, vlad])\
        .then(batch_images, {nimg_s, *imgs, imgs_batch}, imgs_batch)\
        .then(get_descs, imgs_batch, [patch_descs, out_msg2])\
        .then(assign_vlad, [patch_descs, vlad],
                [desc_assignments, out_msg3])\
        .then(get_ca_images,
                [desc_assignments, patch_descs, blend_alpha],
                [imgs_batch, out_msg4])\
        .then(unbatch_images, imgs_batch, imgs2)
    # If the blending changes now, update the cluster images
    blend_alpha.change(get_ca_images,
            [desc_assignments, patch_descs, blend_alpha],
            [imgs_batch, out_msg4])\
        .then(unbatch_images, imgs_batch, imgs2)

print("Interface build completed")


# %%
# Deploy application
demo.queue().launch(share=share)
print("Application deployment ended, exiting...")
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ python3-opencv
2
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ opencv-python
3
+ torch
4
+ torchvision
5
+ torchaudio
6
+ pillow
7
+ matplotlib
8
+ distinctipy
9
+ einops
10
+ fast_pytorch_kmeans
11
+
utilities.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A portable utility module for the demo programs
2
+
3
+
4
+ # %%
5
+ import os
6
+ import numpy as np
7
+ import einops as ein
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ import fast_pytorch_kmeans as fpk
12
+ from typing import Literal, Union, List
13
+
14
+
15
+ # %%
16
+ # Extract features from a Dino-v2 model
17
_DINO_V2_MODELS = Literal["dinov2_vits14", "dinov2_vitb14",
                          "dinov2_vitl14", "dinov2_vitg14"]
_DINO_FACETS = Literal["query", "key", "value", "token"]


class DinoV2ExtractFeatures:
    """
    Extract features from an intermediate layer in Dino-v2
    """
    def __init__(self, dino_model: _DINO_V2_MODELS, layer: int,
                 facet: _DINO_FACETS = "token", use_cls=False,
                 norm_descs=True, device: str = "cpu") -> None:
        """
        Parameters:
        - dino_model:   The DINO-v2 model to use (loaded through
                        `torch.hub` from 'facebookresearch/dinov2')
        - layer:        The layer (block index) to extract features
                        from
        - facet:        "query", "key", or "value" for the attention
                        facets. "token" for the output of the layer.
        - use_cls:      If True, the first item of the hooked output
                        (the CLS token) is also included in the
                        returned descriptors. Otherwise it is dropped.
        - norm_descs:   If True, the descriptors are normalized
        - device:       PyTorch device to use
        """
        # Set before anything that can fail (e.g. the hub download),
        # so that `__del__` always finds these attributes.
        self.fh_handle = None
        self._hook_out = None
        self.vit_type: str = dino_model
        self.dino_model: nn.Module = torch.hub.load(
                'facebookresearch/dinov2', dino_model)
        self.device = torch.device(device)
        self.dino_model = self.dino_model.eval().to(self.device)
        self.layer: int = layer
        self.facet = facet
        self.use_cls = use_cls
        self.norm_descs = norm_descs
        # "token": hook the block output. Otherwise hook the fused
        # qkv projection (split into facets in `__call__`).
        block = self.dino_model.blocks[self.layer]
        hooked_module = block if self.facet == "token" else block.attn.qkv
        self.fh_handle = hooked_module.register_forward_hook(
                self._generate_forward_hook())

    def _generate_forward_hook(self):
        """Return a forward hook that stores the module output."""
        def _forward_hook(module, inputs, output):
            self._hook_out = output
        return _forward_hook

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """
        Run the model on `img` and return the hooked descriptors.

        Parameters:
        - img:   The input image (passed to the model as is)

        Returns:
        - res:   The hooked output, without its first item along
                 dim 1 unless `use_cls` is set. For the "query",
                 "key" and "value" facets, only the corresponding
                 third of the last dimension is kept. Normalized
                 along the last dimension if `norm_descs` is set.
        """
        with torch.no_grad():
            self.dino_model(img)    # The hook stores the layer output
            res = self._hook_out
            if not self.use_cls:
                res = res[:, 1:, ...]
            if self.facet in ["query", "key", "value"]:
                # Last dim is [query | key | value], equal thirds
                d_len = res.shape[2] // 3
                if self.facet == "query":
                    res = res[:, :, :d_len]
                elif self.facet == "key":
                    res = res[:, :, d_len:2*d_len]
                else:
                    res = res[:, :, 2*d_len:]
        if self.norm_descs:
            res = F.normalize(res, dim=-1)
        self._hook_out = None   # Reset the hook
        return res

    def __del__(self):
        # The handle is missing/None if construction did not finish
        handle = getattr(self, "fh_handle", None)
        if handle is not None:
            handle.remove()
90
+
91
+
92
+ # %%
93
+ # VLAD global descriptor implementation
94
class VLAD:
    """
    An implementation of VLAD algorithm given database and query
    descriptors.

    Constructor arguments:
    - num_clusters:     Number of cluster centers for VLAD
    - desc_dim:         Descriptor dimension. If None, then it is
                        inferred when running `fit` method.
    - intra_norm:       If True, intra normalization is applied
                        when constructing VLAD
    - norm_descs:       If True, the given descriptors are
                        normalized before training and predicting
                        VLAD descriptors. Different from the
                        `intra_norm` argument.
    - dist_mode:        Distance mode for KMeans clustering for
                        vocabulary (not residuals). Must be in
                        {'euclidean', 'cosine'}.
    - vlad_mode:        Mode for descriptor assignment (to cluster
                        centers) in VLAD generation. Must be in
                        {'soft', 'hard'}
    - soft_temp:        Temperature for softmax (if 'vlad_mode' is
                        'soft') for assignment
    - cache_dir:        Directory to cache the VLAD vectors. If
                        None, then no caching is done. If a str,
                        then it is assumed as the folder path (it
                        is created if it does not exist). Use
                        absolute paths.

    Cache files (inside `cache_dir`):
    - c_centers.pt:     Cluster centers
    - `cache_id`_r.pt:  Residuals
    - `cache_id`_l.pt:  Labels (hard assignment)
    - `cache_id`_s.pt:  Soft assignment scores

    Notes:
    - Arandjelovic, Relja, and Andrew Zisserman. "All about VLAD."
        Proceedings of the IEEE conference on Computer Vision and
        Pattern Recognition. 2013.
    """
    def __init__(self, num_clusters: int,
                 desc_dim: Union[int, None] = None,
                 intra_norm: bool = True, norm_descs: bool = True,
                 dist_mode: str = "cosine", vlad_mode: str = "hard",
                 soft_temp: float = 1.0,
                 cache_dir: Union[str, None] = None) -> None:
        self.num_clusters = num_clusters
        self.desc_dim = desc_dim
        self.intra_norm = intra_norm
        self.norm_descs = norm_descs
        self.mode = dist_mode
        self.vlad_mode = str(vlad_mode).lower()
        assert self.vlad_mode in ['soft', 'hard']
        self.soft_temp = soft_temp
        # Set in the training phase
        self.c_centers = None
        self.kmeans = None
        # Set the caching
        self.cache_dir = cache_dir
        if self.cache_dir is not None:
            self.cache_dir = os.path.abspath(os.path.expanduser(
                    self.cache_dir))
            if not os.path.exists(self.cache_dir):
                os.makedirs(self.cache_dir)
                print(f"Created cache directory: {self.cache_dir}")
            else:
                print("Warning: Cache directory already exists: " \
                    f"{self.cache_dir}")
        else:
            print("VLAD caching is disabled.")

    def can_use_cache_vlad(self):
        """
        Checks if the cache directory is a valid cache directory.
        For it to be valid, it must exist and should at least
        include the cluster centers file.

        Returns:
        - True if the cache directory is valid
        - False if
            - the cache directory doesn't exist
            - exists but doesn't contain the cluster centers
            - no caching is set in constructor
        """
        if self.cache_dir is None:
            return False
        # The file cannot exist if the directory does not
        return os.path.exists(f"{self.cache_dir}/c_centers.pt")

    def can_use_cache_ids(self,
                cache_ids: Union[List[str], str, None],
                only_residuals: bool = False) -> bool:
        """
        Checks if the given cache IDs exist in the cache directory
        and returns True if all of them exist.

        For every cache ID the residual file (`cache_id`_r.pt) must
        exist. Unless `only_residuals` is True, the assignment file
        of the current `vlad_mode` must exist too (`cache_id`_l.pt
        for 'hard', `cache_id`_s.pt for 'soft').

        The function returns False if cache cannot be used, if
        `cache_ids` is None, or if any of the cache IDs are not
        found.

        This function is mainly for use outside the VLAD class.
        """
        if not self.can_use_cache_vlad():
            return False
        if cache_ids is None:
            return False
        if isinstance(cache_ids, str):
            cache_ids = [cache_ids]
        assign_suffix = "l" if self.vlad_mode == "hard" else "s"
        for cache_id in cache_ids:
            if not os.path.exists(
                    f"{self.cache_dir}/{cache_id}_r.pt"):
                return False
            if not only_residuals and not os.path.exists(
                    f"{self.cache_dir}/{cache_id}_{assign_suffix}.pt"):
                return False
        return True

    # Generate cluster centers
    def fit(self, train_descs: Union[np.ndarray, torch.Tensor, None]):
        """
        Using the training descriptors, generate the cluster
        centers (vocabulary). Function expects all descriptors in
        a single list (see `fit_and_generate` for a batch of
        images).
        If the cache directory is valid, then retrieves cluster
        centers from there (the `train_descs` are ignored).
        Otherwise, stores the cluster centers in the cache
        directory (if using caching).

        Parameters:
        - train_descs:  Training descriptors of shape
                        [num_train_desc, desc_dim]. If None, then
                        caching should be valid (else ValueError).
        """
        # Clustering to create vocabulary
        self.kmeans = fpk.KMeans(self.num_clusters, mode=self.mode)
        # Check if cache exists
        if self.can_use_cache_vlad():
            print("Using cached cluster centers")
            self.c_centers = torch.load(
                    f"{self.cache_dir}/c_centers.pt")
            self.kmeans.centroids = self.c_centers
            if self.desc_dim is None:
                self.desc_dim = self.c_centers.shape[1]
                print(f"Desc dim set to {self.desc_dim}")
        else:
            if train_descs is None:
                raise ValueError("No training descriptors given")
            if isinstance(train_descs, np.ndarray):
                train_descs = torch.from_numpy(train_descs).\
                    to(torch.float32)
            if self.desc_dim is None:
                self.desc_dim = train_descs.shape[1]
            if self.norm_descs:
                train_descs = F.normalize(train_descs)
            self.kmeans.fit(train_descs)
            self.c_centers = self.kmeans.centroids
            if self.cache_dir is not None:
                print("Caching cluster centers")
                torch.save(self.c_centers,
                        f"{self.cache_dir}/c_centers.pt")

    def fit_and_generate(self,
                train_descs: Union[np.ndarray, torch.Tensor]) \
                -> torch.Tensor:
        """
        Given a batch of descriptors over images, `fit` the VLAD
        and generate the global descriptors for the training
        images. Use only when there are a fixed number of
        descriptors in each image.

        Parameters:
        - train_descs:  Training image descriptors of shape
                        [num_imgs, num_descs, desc_dim]. There are
                        'num_imgs' images, each image has
                        'num_descs' descriptors and each
                        descriptor is 'desc_dim' dimensional.

        Returns:
        - train_vlads:  The VLAD vectors of all training images.
                        Shape: [num_imgs, num_clusters*desc_dim]
        """
        # Generate vocabulary
        all_descs = ein.rearrange(train_descs, "n k d -> (n k) d")
        self.fit(all_descs)
        # For each image, stack VLAD
        return torch.stack([self.generate(tr) for tr in train_descs])

    def generate(self, query_descs: Union[np.ndarray, torch.Tensor],
                cache_id: Union[str, None] = None) -> torch.Tensor:
        """
        Given the query descriptors, generate a VLAD vector. Call
        `fit` before using this method. Use this for only single
        images and with descriptors stacked. Use function
        `generate_multi` for multiple images.

        Parameters:
        - query_descs:  Query descriptors of shape [n_q, desc_dim]
                        where 'n_q' is number of 'desc_dim'
                        dimensional descriptors in a query image.
        - cache_id:     If not None (and the cache is usable), the
                        residuals and the assignment are loaded
                        from the files of this ID when they exist,
                        and saved there otherwise.

        Returns:
        - n_vlad:   Normalized VLAD: [num_clusters*desc_dim]
        """
        # Labels / similarities below need a tensor (not numpy)
        if isinstance(query_descs, np.ndarray):
            query_descs = torch.from_numpy(query_descs)\
                .to(torch.float32)
        residuals = self.generate_res_vec(query_descs, cache_id)
        use_cache = cache_id is not None and self.can_use_cache_vlad()
        d = self.desc_dim
        # Un-normalized VLAD vector: [c*d,]
        un_vlad = torch.zeros(self.num_clusters * d)
        if self.vlad_mode == 'hard':
            # Get labels for assignment of descriptors
            l_file = f"{self.cache_dir}/{cache_id}_l.pt"
            if use_cache and os.path.isfile(l_file):
                labels = torch.load(l_file)
            else:
                labels = self.kmeans.predict(query_descs)   # [q]
                if use_cache:
                    torch.save(labels, l_file)
            # Create VLAD from residuals and labels. `unique` (not
            # `.numpy()`) so that labels may live on any device.
            for k in torch.unique(labels).tolist():
                # Sum of residuals for the descriptors in the cluster
                #  Shape:[q, c, d]  ->  [q', d] -> [d]
                cd_sum = residuals[labels == k, k].sum(dim=0)
                if self.intra_norm:
                    cd_sum = F.normalize(cd_sum, dim=0)
                un_vlad[k*d:(k+1)*d] = cd_sum
        else:       # Soft cluster assignment
            s_file = f"{self.cache_dir}/{cache_id}_s.pt"
            if use_cache and os.path.isfile(s_file):
                soft_assign = torch.load(s_file)
            else:
                # Cosine similarity: 1 = close, -1 = away
                cos_sims = F.cosine_similarity(     # [q, c]
                        query_descs[:, None, :],
                        self.c_centers[None, :, :], dim=2)
                soft_assign = F.softmax(self.soft_temp*cos_sims,
                        dim=1)
                if use_cache:
                    torch.save(soft_assign, s_file)
            # Soft assignment scores (as probabilities): [q, c]
            for k in range(0, self.num_clusters):
                w = soft_assign[:, k:k+1]   # [q, 1]
                # Weighted sum of the residuals to cluster k only
                # (not to every cluster): [q, d] -> [d]
                cd_sum = (w * residuals[:, k]).sum(dim=0)
                if self.intra_norm:
                    cd_sum = F.normalize(cd_sum, dim=0)
                un_vlad[k*d:(k+1)*d] = cd_sum
        # Normalize the VLAD vector
        n_vlad = F.normalize(un_vlad, dim=0)
        return n_vlad

    @staticmethod
    def _stack_results(res: list):
        """
        Stack per-image results into one tensor (or numpy array).
        The list itself is returned if stacking is not possible.
        """
        try:    # Most likely pytorch
            return torch.stack(res)
        except (TypeError, RuntimeError):
            # TypeError: items are not tensors. RuntimeError: empty
            # list or tensors with different shapes.
            pass
        try:    # Otherwise numpy
            return np.stack(res)
        except (TypeError, ValueError, RuntimeError):
            return res  # Let it remain as a list

    def generate_multi(self,
            multi_query: Union[np.ndarray, torch.Tensor, list],
            cache_ids: Union[List[str], None] = None) \
            -> Union[torch.Tensor, list]:
        """
        Given query descriptors from multiple images, generate
        the VLAD for them.

        Parameters:
        - multi_query:  Descriptors of shape [n_imgs, n_kpts, d]
                        There are 'n_imgs' and each image has
                        'n_kpts' keypoints, with 'd' dimensional
                        descriptor each. Can also be a list (with
                        a different number of keypoints in each
                        image).
        - cache_ids:    Cache IDs for the VLAD vectors. If None,
                        then no caching is done (stored or
                        retrieved). If a list, then the length
                        should be 'n_imgs' (one per image).

        Returns:
        - multi_res:    VLAD descriptors for the queries (stacked
                        if possible, else a list)
        """
        if cache_ids is None:
            cache_ids = [None] * len(multi_query)
        res = [self.generate(q, c) \
                for (q, c) in zip(multi_query, cache_ids)]
        return self._stack_results(res)

    def generate_res_vec(self,
                query_descs: Union[np.ndarray, torch.Tensor],
                cache_id: Union[str, None] = None) -> torch.Tensor:
        """
        Given the query descriptors, generate the residual vectors
        to all cluster centers. Call `fit` before using this
        method. Use this for only single images and with
        descriptors stacked. Use function `generate_multi_res_vec`
        for multiple images.

        Parameters:
        - query_descs:  Query descriptors of shape [n_q, desc_dim]
                        where 'n_q' is number of 'desc_dim'
                        dimensional descriptors in a query image.
        - cache_id:     If not None (and the cache is usable), the
                        residuals are loaded from the file of this
                        ID if it exists, and saved there otherwise.

        Returns:
        - residuals:    Residual vector: shape [n_q, n_c, d]
        """
        assert self.kmeans is not None
        assert self.c_centers is not None
        use_cache = cache_id is not None and self.can_use_cache_vlad()
        r_file = f"{self.cache_dir}/{cache_id}_r.pt"
        # Compute residuals (all query to cluster): [q, c, d]
        if use_cache and os.path.isfile(r_file):
            residuals = torch.load(r_file)
        else:
            if isinstance(query_descs, np.ndarray):
                query_descs = torch.from_numpy(query_descs)\
                    .to(torch.float32)
            if self.norm_descs:
                query_descs = F.normalize(query_descs)
            # Broadcast: [q, 1, d] - [1, c, d] -> [q, c, d]
            residuals = query_descs[:, None, :] \
                    - self.c_centers[None, :, :]
            if use_cache:
                # The cache ID may contain sub-directories
                cid_dir = f"{self.cache_dir}/"\
                        f"{os.path.split(cache_id)[0]}"
                if not os.path.isdir(cid_dir):
                    os.makedirs(cid_dir)
                    print(f"Created directory: {cid_dir}")
                torch.save(residuals, r_file)
        return residuals

    def generate_multi_res_vec(self,
            multi_query: Union[np.ndarray, torch.Tensor, list],
            cache_ids: Union[List[str], None] = None) \
            -> Union[torch.Tensor, list]:
        """
        Given query descriptors from multiple images, generate
        the residual vectors for them.

        Parameters:
        - multi_query:  Descriptors of shape [n_imgs, n_kpts, d]
                        There are 'n_imgs' and each image has
                        'n_kpts' keypoints, with 'd' dimensional
                        descriptor each. If a List with a
                        different number of keypoints in each
                        image, then the result is also a list.
        - cache_ids:    Cache IDs for the residuals. If None,
                        then no caching is done (stored or
                        retrieved). If a list, then the length
                        should be 'n_imgs' (one per image).

        Returns:
        - multi_res:    Residuals for the queries (stacked if all
                        have the same shape, else a list)
        """
        if cache_ids is None:
            cache_ids = [None] * len(multi_query)
        res = [self.generate_res_vec(q, c) \
                for (q, c) in zip(multi_query, cache_ids)]
        return self._stack_results(res)