import os

# gradio for the visual demo
import gradio as gr

# install the remaining dependencies at startup; the DINOv2 backbone itself is
# pulled from torch.hub below
os.system("pip install scipy")
os.system("pip install torch")
os.system("pip install scikit-learn")
os.system("pip install torchvision")

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import ImageDraw, Image
from typing import Tuple
from scipy.ndimage import binary_closing, binary_opening
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from random import randint

### Models
# 'standard.npy' holds a precomputed projection vector used to separate foreground
# from background patch tokens (see make_foreground_mask); it ships with the Space.
standard_array = np.load('standard.npy')
# DINOv2 ViT-B/14 backbone, used as a frozen patch-feature extractor
pca_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
pca_model.eval()

### Parameters
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
smaller_edge_size = 448  # resize so the smaller image edge is 448 px
interpolation_mode = transforms.InterpolationMode.BICUBIC
patch_size = pca_model.patch_size  # 14 for ViT-B/14
background_threshold = 0.05  # projections below this value are treated as background
apply_opening = True  # morphological cleanup of the foreground mask
apply_closing = True
device = 'cpu'

def make_transform() -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize(size=smaller_edge_size, interpolation=interpolation_mode, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

def prepare_image(image: Image.Image) -> Tuple[torch.Tensor, Tuple[int, int], float]:
    transform = make_transform()
    image_tensor = transform(image)
    resize_scale = image.width / image_tensor.shape[2]  # original pixels per resized pixel

    # Crop image to dimensions that are a multiple of the patch size
    height, width = image_tensor.shape[1:]  # C x H x W
    cropped_width, cropped_height = width - width % patch_size, height - height % patch_size
    image_tensor = image_tensor[:, :cropped_height, :cropped_width]

    grid_size = (cropped_height // patch_size, cropped_width // patch_size)  # rows x cols (h x w)
    return image_tensor, grid_size, resize_scale
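
# Illustration (not part of the app logic): for a 640x480 input, Resize(448) yields a
# 597x448 tensor, resize_scale = 640 / 597 ≈ 1.07, the tensor is cropped to 588x448
# (multiples of the 14-px patch), and grid_size = (32, 42), i.e. 1344 patch tokens.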

def make_foreground_mask(tokens,
                         grid_size: Tuple[int, int]):
    # Project patch tokens onto the precomputed foreground direction and threshold;
    # tokens arrive as a torch tensor, so convert to numpy before mixing with standard_array
    projection = tokens.cpu().numpy() @ standard_array
    mask = projection >= background_threshold
    mask = mask.reshape(*grid_size)
    # Optional morphological cleanup: opening removes isolated foreground patches,
    # closing fills small holes
    if apply_opening:
        mask = binary_opening(mask)
    if apply_closing:
        mask = binary_closing(mask)
    return mask.flatten()

def render_patch_pca(image: Image.Image, mask, filter_background, tokens, grid_size) -> Image.Image:
    # Fit a 3-component PCA on the patch tokens (foreground patches only, if requested)
    pca = PCA(n_components=3)
    if filter_background:
        pca.fit(tokens[mask])
    else:
        pca.fit(tokens)
    projected_tokens = pca.transform(tokens)

    # Normalize each component to [0, 255] so the three components can be shown as RGB
    t = torch.tensor(projected_tokens)
    t_min = t.min(dim=0, keepdim=True).values
    t_max = t.max(dim=0, keepdim=True).values
    normalized_t = (t - t_min) / (t_max - t_min)

    array = (normalized_t * 255).byte().numpy()
    if filter_background:
        array[~mask] = 0
    array = array.reshape(*grid_size, 3)
    return Image.fromarray(array).resize((image.width, image.height), Image.NEAREST)
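
# Mapping the first three principal components of the patch features to the R, G and B
# channels gives the kind of dense-feature visualization shown in the DINOv2 paper;
# with filter_background=True, background patches are excluded from the PCA fit and
# rendered black.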

def extract_features(img, image_tensor, filter_background, grid_size):
    with torch.inference_mode():
        image_batch = image_tensor.unsqueeze(0).to(device)
        # Patch tokens from the last intermediate layer: (num_patches, feature_dim)
        tokens = pca_model.get_intermediate_layers(image_batch)[0].squeeze()
        mask = make_foreground_mask(tokens, grid_size)
        img_pca = render_patch_pca(img, mask, filter_background, tokens, grid_size)
        return tokens.cpu().numpy(), mask, img_pca

def compute_features(img, filter_background):
    image_tensor, grid_size, resize_scale = prepare_image(img)
    features, mask, img_pca = extract_features(img, image_tensor, filter_background, grid_size)
    return features, mask, grid_size, resize_scale, img_pca
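
# Shape summary (assuming dinov2_vitb14, embedding dim 768): `features` is
# (grid_h * grid_w, 768), `mask` is a flat boolean array of the same length, and
# `img_pca` is a PIL image of the PCA visualization resized back to the input size.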

def idx_to_source_position(idx, grid_size, resize_scale):
    # Map a flat patch index back to approximate pixel coordinates (patch centers)
    # in the original image
    row = (idx // grid_size[1]) * pca_model.patch_size * resize_scale + pca_model.patch_size / 2
    col = (idx % grid_size[1]) * pca_model.patch_size * resize_scale + pca_model.patch_size / 2
    return row, col
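
# For the 640x480 example above (grid_size = (32, 42), resize_scale ≈ 1.07), idx = 100
# lies in patch row 100 // 42 = 2, patch column 100 % 42 = 16, which maps to roughly
# (row, col) ≈ (37, 247) in original-image pixels.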

def compute_nn(features1, features2):
    # For every feature in features2, find its nearest neighbour in features1
    knn = NearestNeighbors(n_neighbors=1)
    knn.fit(features1)
    distances, match2to1 = knn.kneighbors(features2)
    return distances.flatten(), match2to1.flatten()
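
# kneighbors with n_neighbors=1 returns arrays of shape (len(features2), 1); they are
# flattened so that downstream lookups (match1to2[idx1], mask1[idx1]) use plain integer
# indices. Running compute_nn in both directions enables the mutual ("symmetric")
# match check in compute_matches.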

def compute_matches(img1, img2, lr_check, filter_background, display_matches_threshold):
    # compute features
    features1, mask1, grid_size1, resize_scale1, img_pca1 = compute_features(img1, filter_background)
    features2, mask2, grid_size2, resize_scale2, img_pca2 = compute_features(img2, filter_background)

    # match features in both directions
    distances2to1, match2to1 = compute_nn(features1, features2)
    distances1to2, match1to2 = compute_nn(features2, features1)

    # resize the taller image so both images share the same height in the side-by-side view
    if img1.size[1] > img2.size[1]:
        img1 = img1.resize(img2.size)
        resize_scale1 = resize_scale2
    else:
        img2 = img2.resize(img1.size)
        resize_scale2 = resize_scale1

    # display matches (create the Draw objects after resizing so the drawn points
    # end up on the images that are actually returned)
    draw1 = ImageDraw.Draw(img1)
    draw2 = ImageDraw.Draw(img2)
    offset = img1.size[0]
    merged_image = Image.new('RGB', (offset + img2.size[0], max(img1.size[1], img2.size[1])), (250, 250, 250))
    merged_image.paste(img1, (0, 0))
    merged_image.paste(img2, (offset, 0))
    draw = ImageDraw.Draw(merged_image)

    for idx2, idx1 in enumerate(match2to1):
        # keep only mutual (symmetric) matches if requested
        if lr_check and match1to2[idx1] != idx2: continue
        # skip matches that fall on the background of either image
        if filter_background and not mask1[idx1]: continue
        if filter_background and not mask2[idx2]: continue
        row1, col1 = idx_to_source_position(idx1, grid_size1, resize_scale1)
        row2, col2 = idx_to_source_position(idx2, grid_size2, resize_scale2)
        r = randint(0, 255)
        g = randint(0, 255)
        color = (r, g, 255 - r)
        draw1.point((col1, row1), color)
        draw2.point((col2, row2), color)
        # only draw a connecting line for a random subset of matches, to keep the plot readable
        if 100 * np.random.rand() > display_matches_threshold: continue
        draw.line((col1, row1, col2 + offset, row2), fill=color)

    return [[img1, img_pca1], [img2, img_pca2], merged_image]
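
# The return value lines up with the Gradio outputs below: one 2-image gallery per input
# (drawn image + PCA visualization) and the merged side-by-side match image.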

iface = gr.Interface(
    fn=compute_matches,
    inputs=[
        gr.Image(type="pil"),
        gr.Image(type="pil"),
        gr.Checkbox(label="Keep only symmetric matches"),
        gr.Checkbox(label="Mask background"),
        gr.Slider(0, 100, step=5, value=5, label="Display matches ratio", info="Choose between 0 and 100%"),
    ],
    outputs=[
        gr.Gallery(
            label="Image 1", show_label=False, elem_id="gallery1",
            columns=[2], rows=[1], object_fit="contain", height="auto"),
        gr.Gallery(
            label="Image 2", show_label=False, elem_id="gallery2",
            columns=[2], rows=[1], object_fit="contain", height="auto"),
        gr.Image(type="pil"),
    ])

iface.launch(debug=True)