# Dinov2-Matching / app.py
import os
# gradio for the visual demo
import gradio as gr
# Hugging Face Spaces workaround: install dependencies that are missing from the base image
os.system("pip install scipy")
os.system("pip install torch")
os.system("pip install scikit-learn")
os.system("pip install torchvision")
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import ImageDraw, Image
from typing import Tuple
from scipy.ndimage import binary_closing, binary_opening
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from random import randint
### Models
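# Precomputed direction onto which patch tokens are projected to separate foreground from background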
standard_array = np.load('standard.npy')
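# DINOv2 ViT-B/14 backbone loaded from torch.hub; despite the name, this is the feature extractor, not a PCA model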
pca_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
pca_model.eval()
### Parameters
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
smaller_edge_size = 448
interpolation_mode = transforms.InterpolationMode.BICUBIC
patch_size = pca_model.patch_size
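# Patches whose projection onto standard_array falls below this value are treated as background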
background_threshold = 0.05
apply_opening = True
apply_closing = True
device = 'cpu'
def make_transform() -> transforms.Compose:
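    """Preprocessing pipeline: resize the smaller edge to 448 px, convert to tensor, normalise with ImageNet statistics."""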
return transforms.Compose([
transforms.Resize(size=smaller_edge_size, interpolation=interpolation_mode, antialias=True),
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])
def prepare_image(image: Image.Image) -> Tuple[torch.Tensor, Tuple[int, int], float]:
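    """Resize and crop the image to a multiple of the patch size; return the tensor, token grid size (rows, cols) and resize scale."""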
transform = make_transform()
image_tensor = transform(image)
resize_scale = image.width / image_tensor.shape[2]
# Crop image to dimensions that are a multiple of the patch size
height, width = image_tensor.shape[1:] # C x H x W
cropped_width, cropped_height = width - width % patch_size, height - height % patch_size
image_tensor = image_tensor[:, :cropped_height, :cropped_width]
    grid_size = (cropped_height // patch_size, cropped_width // patch_size)  # (rows, cols) of patch tokens
return image_tensor, grid_size, resize_scale
def make_foreground_mask(tokens,
grid_size: Tuple[int, int]):
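    """Project patch tokens onto standard_array and threshold to get a flattened foreground mask, cleaned with morphological opening/closing."""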
projection = tokens @ standard_array
mask = projection >= background_threshold
mask = mask.reshape(*grid_size)
if apply_opening:
mask = binary_opening(mask)
if apply_closing:
mask = binary_closing(mask)
return mask.flatten()
def render_patch_pca(image: Image.Image, mask, filter_background, tokens, grid_size) -> Image.Image:
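    """Map patch tokens to RGB by projecting them onto their first three PCA components, then upsample to the input image size."""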
    pca = PCA(n_components=3)
    # fit the PCA on foreground tokens only when the background is masked out
    if filter_background:
        pca.fit(tokens[mask])
    else:
        pca.fit(tokens)
    projected_tokens = pca.transform(tokens)
    t = torch.tensor(projected_tokens)
    t_min = t.min(dim=0, keepdim=True).values
    t_max = t.max(dim=0, keepdim=True).values
    # rescale each PCA component to [0, 255] for display
    normalized_t = (t - t_min) / (t_max - t_min)
    array = (normalized_t * 255).byte().numpy()
    if filter_background:
        array[~mask] = 0
    array = array.reshape(*grid_size, 3)
    return Image.fromarray(array).resize((image.width, image.height), resample=Image.NEAREST)
def extract_features(img, image_tensor, filter_background, grid_size):
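    """Run the DINOv2 backbone on the image tensor and return patch tokens, the foreground mask and the PCA visualisation."""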
with torch.inference_mode():
image_batch = image_tensor.unsqueeze(0).to(device)
tokens = pca_model.get_intermediate_layers(image_batch)[0].squeeze()
mask = make_foreground_mask(tokens, grid_size)
img_pca = render_patch_pca(img, mask, filter_background, tokens, grid_size)
return tokens.cpu().numpy(), mask, img_pca
def compute_features(img, filter_background):
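    """Preprocess an image and compute its patch features, foreground mask, grid size, resize scale and PCA visualisation."""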
image_tensor, grid_size, resize_scale = prepare_image(img)
features, mask, img_pca = extract_features(img, image_tensor, filter_background, grid_size)
return features, mask, grid_size, resize_scale, img_pca
def idx_to_source_position(idx, grid_size, resize_scale):
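    """Convert a flattened patch index into (row, col) pixel coordinates in the source image."""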
row = (idx // grid_size[1])*pca_model.patch_size*resize_scale + pca_model.patch_size / 2
col = (idx % grid_size[1])*pca_model.patch_size*resize_scale + pca_model.patch_size / 2
return row, col
def compute_nn(features1, features2):
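    """For each feature vector in features2, find the index of its nearest neighbour in features1."""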
    knn = NearestNeighbors(n_neighbors=1)
    knn.fit(features1)
    distances, match2to1 = knn.kneighbors(features2)
    # kneighbors returns (n, 1) arrays; flatten to 1-D so the indices behave as scalars downstream
    return distances.flatten(), match2to1.flatten()
def compute_matches(img1, img2, lr_check, filter_background, display_matches_threshold):
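    """Match DINOv2 patch features between two images and return the annotated images, PCA visualisations and a side-by-side match view."""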
# compute features
features1, mask1, grid_size1, resize_scale1, img_pca1 = compute_features(img1, filter_background)
features2, mask2, grid_size2, resize_scale2, img_pca2 = compute_features(img2, filter_background)
# match features
distances2to1, match2to1 = compute_nn(features1, features2)
distances1to2, match1to2 = compute_nn(features2, features1)
# display matches
    if img1.size[1] > img2.size[1]:
        img1 = img1.resize(img2.size)
        resize_scale1 = resize_scale2
    else:
        img2 = img2.resize(img1.size)
        resize_scale2 = resize_scale1
    # create the drawing contexts after resizing so the match points land on the images that are returned
    draw1 = ImageDraw.Draw(img1)
    draw2 = ImageDraw.Draw(img2)
    # place the two images side by side for the match lines
    offset = img1.size[0]
    merged_image = Image.new('RGB', (offset + img2.size[0], max(img1.size[1], img2.size[1])), (250, 250, 250))
    merged_image.paste(img1, (0, 0))
    merged_image.paste(img2, (offset, 0))
    draw = ImageDraw.Draw(merged_image)
    for idx2, idx1 in enumerate(match2to1):
        # optional left-right check: keep only mutual nearest-neighbour matches
        if lr_check and match1to2[idx1] != idx2: continue
        row1, col1 = idx_to_source_position(idx1, grid_size1, resize_scale1)
        row2, col2 = idx_to_source_position(idx2, grid_size2, resize_scale2)
        # skip background patches when the background mask is enabled
        if filter_background and not mask1[idx1]: continue
        if filter_background and not mask2[idx2]: continue
        # draw each match point with a random colour shared by both images
        r = randint(0, 255)
        g = randint(0, 255)
        color = (r, g, 255 - r)
        draw1.point((col1, row1), color)
        draw2.point((col2, row2), color)
        # only draw a connecting line for a random subset of matches, controlled by the slider
        if 100 * np.random.rand() > display_matches_threshold: continue
        draw.line((col1, row1, col2 + offset, row2), fill=color)
return [[img1, img_pca1], [img2, img_pca2], merged_image]
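# Gradio interface: two input images plus matching options; outputs are one gallery per image
# (original + PCA visualisation) and the merged side-by-side match image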
iface = gr.Interface(
    fn=compute_matches,
    inputs=[
        gr.Image(type="pil"),
        gr.Image(type="pil"),
        gr.Checkbox(label="Keep only symmetric matches"),
        gr.Checkbox(label="Mask background"),
        gr.Slider(0, 100, step=5, value=5, label="Display matches ratio", info="Choose between 0 and 100%"),
    ],
    outputs=[
        gr.Gallery(
            label="Image 1", show_label=False, elem_id="gallery1",
            columns=[2], rows=[1], object_fit="contain", height="auto"),
        gr.Gallery(
            label="Image 2", show_label=False, elem_id="gallery2",
            columns=[2], rows=[1], object_fit="contain", height="auto"),
        gr.Image(type="pil"),
    ])
iface.launch(debug=True)