Spaces:
Runtime error
Runtime error
# Show VLAD clustering for set of example images or a user image | |
""" | |
User input: | |
- Domain: Indoor, Aerial, or Urban | |
- Image: Image to be clustered | |
- Cluster numbers (to visualize) | |
- Pixel coordinates (to pick further clusters) | |
- A unique cache ID (to store the DINO forward passes) | |
There are example images for each domain. | |
Output: | |
- All images with cluster assignments | |
Some Gradio links: | |
- Controlling layout | |
- https://www.gradio.app/guides/quickstart#blocks-more-flexibility-and-control | |
- Data state (persistence) | |
- https://www.gradio.app/guides/interface-state | |
- https://www.gradio.app/docs/state | |
- Layout control | |
- https://www.gradio.app/guides/controlling-layout | |
- https://www.gradio.app/guides/blocks-and-event-listeners | |
""" | |
# %% | |
import os | |
import gradio as gr | |
import numpy as np | |
import cv2 as cv | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torchvision import transforms as tvf | |
from torchvision.transforms import functional as T | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import distinctipy as dipy | |
from typing import Literal, List | |
import gradio as gr | |
import time | |
import glob | |
import shutil | |
from copy import deepcopy | |
# DINOv2 imports | |
from utilities import DinoV2ExtractFeatures | |
from utilities import VLAD | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# %% | |
# Configurations | |
T1 = Literal["query", "key", "value", "token"] | |
T2 = Literal["aerial", "indoor", "urban"] | |
DOMAINS = ["aerial", "indoor", "urban"] | |
T3 = Literal["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", | |
"dinov2_vitg14"] | |
_ex = lambda x: os.path.realpath(os.path.expanduser(x)) | |
dino_model: T3 = "dinov2_vitg14" | |
desc_layer: int = 31 | |
desc_facet: T1 = "value" | |
num_c: int = 8 | |
cache_dir: str = _ex("./cache") # Directory containing program cache | |
max_img_size: int = 1024 # Image resolution (max dim/size) | |
max_num_imgs: int = 10 # Max number of images to upload | |
share: bool = False # Share application using .gradio link | |
# Verify inputs | |
assert os.path.isdir(cache_dir), "Cache directory not found" | |
# %% | |
# Model and transforms | |
print("Loading DINO model") | |
# extractor = DinoV2ExtractFeatures(dino_model, desc_layer, desc_facet, | |
# device=device) | |
extractor = None | |
print("DINO model loaded") | |
# VLAD path (directory) | |
ext_s = f"{dino_model}/l{desc_layer}_{desc_facet}_c{num_c}" | |
vc_dir = os.path.join(cache_dir, "vocabulary", ext_s) | |
# Base image transformations | |
base_tf = tvf.Compose([ | |
tvf.ToTensor(), | |
tvf.Normalize(mean=[0.485, 0.456, 0.406], | |
std=[0.229, 0.224, 0.225]) | |
]) | |
# %% | |
# Get VLAD object | |
def get_vlad_clusters(domain, pr = gr.Progress()): | |
dm: T2 = str(domain).lower() | |
assert dm in DOMAINS, "Invalid domain" | |
# Load VLAD cluster centers | |
pr(0, desc="Loading VLAD clusters") | |
c_centers_file = os.path.join(vc_dir, dm, "c_centers.pt") | |
if not os.path.isfile(c_centers_file): | |
return f"Cluster centers not found for: {domain}", None | |
c_centers = torch.load(c_centers_file) | |
pr(0.5) | |
num_c = c_centers.shape[0] | |
desc_dim = c_centers.shape[1] | |
vlad = VLAD(num_c, desc_dim, | |
cache_dir=os.path.dirname(c_centers_file)) | |
vlad.fit(None) # Restore the cache | |
pr(1) | |
return f"VLAD clusters loaded for: {domain}", vlad | |
# %% | |
# Get VLAD descriptors | |
def get_descs(imgs_batch, pr = gr.Progress()): | |
imgs_batch: List[np.ndarray] = imgs_batch | |
pr(0, desc="Extracting descriptors") | |
patch_descs = [] | |
for i, img in enumerate(imgs_batch): | |
# Convert to PIL image | |
pil_img = Image.fromarray(img) | |
img_pt = base_tf(pil_img).to(device) | |
if max(img_pt.shape[-2:]) > max_img_size: | |
print(f"Image {i+1}: {img_pt.shape[-2:]}, outside") | |
c, h, w = img_pt.shape | |
# Maintain aspect ratio | |
if h == max(img_pt.shape[-2:]): | |
w = int(w * max_img_size / h) | |
h = max_img_size | |
else: | |
h = int(h * max_img_size / w) | |
w = max_img_size | |
img_pt = T.resize(img_pt, (h, w), | |
interpolation=T.InterpolationMode.BICUBIC) | |
pil_img = pil_img.resize((w, h)) # Backup | |
# Make image patchable | |
c, h, w = img_pt.shape | |
h_new, w_new = (h // 14) * 14, (w // 14) * 14 | |
img_pt = tvf.CenterCrop((h_new, w_new))(img_pt)[None, ...] | |
# Extract descriptors | |
ret = extractor(img_pt).cpu() # [1, n_p, d] | |
patch_descs.append({"img": pil_img, "descs": ret}) | |
pr((i+1) / len(imgs_batch)) | |
return patch_descs, \ | |
f"Descriptors extracted for {len(imgs_batch)} images" | |
# %% | |
# Assign VLAD clusters (descriptor assignment) | |
def assign_vlad(patch_descs, vlad, pr = gr.Progress()): | |
vlad: VLAD = vlad | |
img_patch_descs = [pd["descs"] for pd in patch_descs] | |
pr(0, desc="Assigning VLAD clusters") | |
desc_assignments = [] # List[Tensor;shape=('h', 'w');int] | |
for i, qu_desc in enumerate(img_patch_descs): | |
# Residual vectors; 'n' could differ (based on img sizes) | |
res = vlad.generate_res_vec(qu_desc[0]) # ['n', n_c, d] | |
img = patch_descs[i]["img"] | |
h, w, c = np.array(img).shape | |
h_p, w_p = h // 14, w // 14 | |
h_new, w_new = h_p * 14, w_p * 14 | |
assert h_p * w_p == res.shape[0], "Residual incorrect!" | |
# Descriptor assignments | |
da = res.abs().sum(dim=2).argmin(dim=1).reshape(h_p, w_p) | |
da = F.interpolate(da[None, None, ...].to(float), | |
(h_new, w_new), mode="nearest")[0, 0].to(da.dtype) | |
desc_assignments.append(da) | |
pr((i+1) / len(img_patch_descs)) | |
pr(1.0) | |
return desc_assignments, "VLAD clusters assigned" | |
# %% | |
# Cluster assignments to images | |
def get_ca_images(desc_assignments, patch_descs, alpha, | |
pr = gr.Progress()): | |
if desc_assignments is None or len(desc_assignments) == 0: | |
return None, "First load images" | |
c_colors = dipy.get_colors(num_c, rng=928, | |
colorblind_type="Deuteranomaly") | |
np_colors = (np.array(c_colors) * 255).astype(np.uint8) | |
# Get images with clusters | |
pil_imgs = [pd["img"] for pd in patch_descs] | |
res_imgs = [] # List[PIL.Image] | |
pr(0, desc="Generating cluster assignment images") | |
for i, pil_img in enumerate(pil_imgs): | |
# Descriptor assignment image: [h, w, 3] | |
da: torch.Tensor = desc_assignments[i] # ['h', 'w'] | |
da_img = np.zeros((*da.shape, 3), dtype=np.uint8) | |
for c in range(num_c): | |
da_img[da == c] = np_colors[c] | |
# Background image: [h, w, 3] | |
img_np = np.array(pil_img, dtype=np.uint8) | |
h, w, c = np.array(img_np).shape | |
h_p, w_p = (h // 14), (w // 14) | |
h_new, w_new = h_p * 14, w_p * 14 | |
img_np = F.interpolate(torch.tensor(img_np)\ | |
.permute(2, 0, 1)[None, ...], (h_new, w_new), | |
mode='nearest')[0].permute(1, 2, 0).numpy() | |
res_img = cv.addWeighted(img_np, 1 - alpha, da_img, alpha, 0.) | |
res_imgs.append(Image.fromarray(res_img)) | |
pr((i+1) / len(pil_imgs)) | |
pr(1.0) | |
return res_imgs, "Cluster assignment images generated" | |
# %% | |
print("Interface build started") | |
# Build the interface | |
with gr.Blocks() as demo: | |
# ---- Helper functions ---- | |
# Variable number of input images | |
def var_num_img(s): | |
n = int(s) # Slider value as int | |
return [gr.Image.update(label=f"Image {i+1}", visible=True) \ | |
for i in range(n)] + [gr.Image.update(visible=False) \ | |
for _ in range(max_num_imgs - n)] | |
# ---- State declarations ---- | |
vlad = gr.State() # VLAD object | |
desc_assignments = gr.State() # Cluster assignments | |
imgs_batch = gr.State() # Images as batch | |
patch_descs = gr.State() # Patch descriptors | |
# ---- All UI elements ---- | |
d_vals = [k.title() for k in DOMAINS] | |
domain = gr.Radio(d_vals, value=d_vals[0]) | |
nimg_s = gr.Slider(1, max_num_imgs, value=1, step=1, | |
label="How many images?") # How many images? | |
with gr.Row(): # Dynamic row (images in columns) | |
imgs = [gr.Image(label=f"Image {i+1}", visible=True) \ | |
for i in range(nimg_s.value)] + \ | |
[gr.Image(visible=False) \ | |
for _ in range(max_num_imgs - nimg_s.value)] | |
for i, img in enumerate(imgs): # Set image as "input" | |
img.change(lambda _: None, img) | |
with gr.Row(): # Dynamic row of output (cluster) images | |
imgs2 = [gr.Image(label=f"VLAD Clusters {i+1}", | |
visible=False) for i in range(max_num_imgs)] | |
nimg_s.change(var_num_img, nimg_s, imgs) | |
blend_alpha = gr.Slider(0, 1, 0.4, step=0.01, # Cluster centers | |
label="Blend alpha (weight for cluster centers)") | |
bttn1 = gr.Button("Click Me!") # Cluster assignment | |
out_msg1 = gr.Markdown("Select domain and upload images") | |
out_msg2 = gr.Markdown("For descriptor extraction") | |
out_msg3 = gr.Markdown("Followed by VLAD assignment") | |
out_msg4 = gr.Markdown("Followed by cluster images") | |
# ---- Utility functions ---- | |
# A wrapper to batch the images | |
def batch_images(data): | |
sv = data[nimg_s] | |
images: List[np.ndarray] = [data[imgs[k]] \ | |
for k in range(sv)] | |
return images | |
# A wrapper to unbatch images (and pad to max) | |
def unbatch_images(imgs_batch): | |
ret = [gr.Image.update(visible=False) \ | |
for _ in range(max_num_imgs)] | |
if imgs_batch is None or len(imgs_batch) == 0: | |
return ret | |
for i, img_pil in enumerate(imgs_batch): | |
img_np = np.array(img_pil) | |
ret[i] = gr.Image.update(img_np, visible=True) | |
return ret | |
# ---- Main pipeline ---- | |
# Get the VLAD cluster assignment images on click | |
bttn1.click(get_vlad_clusters, domain, [out_msg1, vlad])\ | |
.then(batch_images, {nimg_s, *imgs, imgs_batch}, imgs_batch)\ | |
.then(get_descs, imgs_batch, [patch_descs, out_msg2])\ | |
.then(assign_vlad, [patch_descs, vlad], | |
[desc_assignments, out_msg3])\ | |
.then(get_ca_images, | |
[desc_assignments, patch_descs, blend_alpha], | |
[imgs_batch, out_msg4])\ | |
.then(unbatch_images, imgs_batch, imgs2) | |
# If the blending changes now, update the cluster images | |
blend_alpha.change(get_ca_images, | |
[desc_assignments, patch_descs, blend_alpha], | |
[imgs_batch, out_msg4])\ | |
.then(unbatch_images, imgs_batch, imgs2) | |
print("Interface build completed") | |
# %% | |
# Deploy application | |
demo.queue().launch(share=share) | |
print("Application deployment ended, exiting...") | |