import json
import math
import pickle
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Mapping

import cv2
import matplotlib.cm as cm
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import tqdm
from PIL import Image
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.decomposition import PCA
from torch.nn.parameter import Parameter
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.transforms import functional


class _LoRA_qkv(nn.Module):
    """LoRA adapter around DINOv2's fused qkv projection.

    In DINOv2 the projection is implemented as:
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv
        self.linear_a_q = linear_a_q
        self.linear_b_q = linear_b_q
        self.linear_a_v = linear_a_v
        self.linear_b_v = linear_b_v
        self.dim = qkv.in_features
        self.w_identity = torch.eye(qkv.in_features)

    def forward(self, x):
        qkv = self.qkv(x)  # B, N, 3 * org_C
        # Low-rank updates are added only to the q and v slices of the fused output;
        # k is left untouched.
        new_q = self.linear_b_q(self.linear_a_q(x))
        new_v = self.linear_b_v(self.linear_a_v(x))
        qkv[:, :, : self.dim] += new_q
        qkv[:, :, -self.dim:] += new_v
        return qkv


def sigmoid(tensor, temp=1.0):
    """Temperature-controlled sigmoid.

    Passes the input tensor through a sigmoid whose sharpness is controlled by
    `temp`; the exponent is clamped for numerical stability.
    """
    exponent = -tensor / temp
    exponent = torch.clamp(exponent, min=-50, max=50)
    y = 1.0 / (1.0 + torch.exp(exponent))
    return y


def interpolate_features(descriptors, pts, h, w, normalize=True, patch_size=14, stride=14):
    # Map pixel coordinates to the [-1, 1] range expected by F.grid_sample,
    # aligning the centers of the first and last patches with -1 and +1.
    last_coord_h = ((h - patch_size) // stride) * stride + (patch_size / 2)
    last_coord_w = ((w - patch_size) // stride) * stride + (patch_size / 2)
    ah = 2 / (last_coord_h - (patch_size / 2))
    aw = 2 / (last_coord_w - (patch_size / 2))
    bh = 1 - last_coord_h * 2 / (last_coord_h - (patch_size / 2))
    bw = 1 - last_coord_w * 2 / (last_coord_w - (patch_size / 2))

    a = torch.tensor([[aw, ah]]).to(pts).float()
    b = torch.tensor([[bw, bh]]).to(pts).float()
    keypoints = a * pts + b

    # Expand dimensions for grid sampling: [batch_size, 1, num_keypoints, 2]
    keypoints = keypoints.unsqueeze(-3)

    # Interpolate using bilinear sampling; the result has shape
    # [batch_size, channels, 1, num_keypoints].
    interpolated_features = F.grid_sample(descriptors, keypoints, align_corners=True, padding_mode='border')
    interpolated_features = interpolated_features.squeeze(-2)
    return F.normalize(interpolated_features, dim=1) if normalize else interpolated_features
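
# --- Illustrative sanity check (added sketch; shapes are hypothetical, not
# part of the original training code). Samples a 384-dim descriptor map on a
# 16x16 patch grid at 5 random keypoints of a 224x224 image.
def _demo_interpolate_features():
    descriptors = torch.randn(1, 384, 16, 16)  # B x C x patch_h x patch_w
    pts = torch.rand(1, 5, 2) * 223            # keypoints in (x, y) pixel coordinates
    feats = interpolate_features(descriptors, pts, h=224, w=224)
    assert feats.shape == (1, 384, 5)          # B x C x num_keypoints
    return feats
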
class FinetuneDINO(pl.LightningModule):
    def __init__(self, r, backbone_size, reg=False, datasets=None):
        super().__init__()
        assert r > 0
        self.backbone_size = backbone_size
        self.backbone_archs = {
            "small": "vits14",
            "base": "vitb14",
            "large": "vitl14",
            "giant": "vitg14",
        }
        self.embedding_dims = {
            "small": 384,
            "base": 768,
            "large": 1024,
            "giant": 1536,
        }
        self.backbone_arch = self.backbone_archs[self.backbone_size]
        if reg:
            self.backbone_arch = f"{self.backbone_arch}_reg"
        self.embedding_dim = self.embedding_dims[self.backbone_size]
        self.backbone_name = f"dinov2_{self.backbone_arch}"
        dinov2 = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=self.backbone_name)
        self.datasets = datasets
        # Only apply LoRA to the image encoder by default.
        self.lora_layer = list(range(len(dinov2.blocks)))

        # Create storage first, then either initialize the layers or load weights.
        self.w_As = []  # low-rank A projections (linear layers)
        self.w_Bs = []  # low-rank B projections (linear layers)

        # Freeze the whole backbone first.
        for param in dinov2.parameters():
            param.requires_grad = False

        # Fine-tune the last 4 blocks with LoRA.
        for t_layer_i, blk in enumerate(dinov2.blocks[-4:]):
            # Skip blocks that are not selected for LoRA.
            if t_layer_i not in self.lora_layer:
                continue
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            w_a_linear_q = nn.Linear(self.dim, r, bias=False)
            w_b_linear_q = nn.Linear(r, self.dim, bias=False)
            w_a_linear_v = nn.Linear(self.dim, r, bias=False)
            w_b_linear_v = nn.Linear(r, self.dim, bias=False)
            self.w_As.append(w_a_linear_q)
            self.w_Bs.append(w_b_linear_q)
            self.w_As.append(w_a_linear_v)
            self.w_Bs.append(w_b_linear_v)
            blk.attn.qkv = _LoRA_qkv(
                w_qkv_linear,
                w_a_linear_q,
                w_b_linear_q,
                w_a_linear_v,
                w_b_linear_v,
            )
        self.reset_parameters()
        self.dinov2 = dinov2

        self.downsample_factor = 8
        self.refine_conv = nn.Conv2d(self.embedding_dim, self.embedding_dim, kernel_size=3, stride=1, padding=1)
        self.thresh3d_pos = 5e-3
        self.thresh3d_neg = 0.1
        self.patch_size = 14
        self.target_res = 640
        self.input_transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

    def reset_parameters(self) -> None:
        # Standard LoRA init: A is Kaiming-initialized, B starts at zero so the
        # adapters are a no-op before training.
        for w_A in self.w_As:
            nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
        num_layer = len(self.w_As)  # each adapted block contributes two entries (q and v)
        a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
        b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
        checkpoint['state_dict'] = {
            'refine_conv': self.refine_conv.state_dict(),
        }
        checkpoint.update(a_tensors)
        checkpoint.update(b_tensors)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        # Intentionally a no-op: only the LoRA and refine_conv weights are
        # checkpointed, and they are restored in on_load_checkpoint instead.
        pass

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.refine_conv.load_state_dict(checkpoint['state_dict']['refine_conv'])
        for i, w_A_linear in enumerate(self.w_As):
            saved_key = f"w_a_{i:03d}"
            saved_tensor = checkpoint[saved_key]
            w_A_linear.weight = Parameter(saved_tensor)
        for i, w_B_linear in enumerate(self.w_Bs):
            saved_key = f"w_b_{i:03d}"
            saved_tensor = checkpoint[saved_key]
            w_B_linear.weight = Parameter(saved_tensor)
        self.loaded = True

    def get_nearest(self, query, database):
        dist = torch.cdist(query, database)
        min_dist, min_idx = torch.min(dist, -1)
        return min_dist, min_idx

    def get_feature(self, rgbs, pts, normalize=True):
        # Resize so the longer side matches target_res, then snap to a multiple
        # of the patch size so the token grid tiles the image exactly.
        tgt_size = (int(rgbs.shape[-2] * self.target_res / rgbs.shape[-1]), self.target_res)
        if rgbs.shape[-2] > rgbs.shape[-1]:
            tgt_size = (self.target_res, int(rgbs.shape[-1] * self.target_res / rgbs.shape[-2]))

        patch_h, patch_w = tgt_size[0] // self.downsample_factor, tgt_size[1] // self.downsample_factor
        rgb_resized = functional.resize(rgbs, (patch_h * self.patch_size, patch_w * self.patch_size))
        resize_factor = [(patch_w * self.patch_size) / rgbs.shape[-1], (patch_h * self.patch_size) / rgbs.shape[-2]]
        pts = pts * torch.tensor(resize_factor).to(pts.device)

        result = self.dinov2.forward_features(self.input_transform(rgb_resized))
        feature = result['x_norm_patchtokens'].reshape(rgb_resized.shape[0], patch_h, patch_w, -1).permute(0, 3, 1, 2)
        feature = self.refine_conv(feature)
        feature = interpolate_features(feature, pts, h=patch_h * self.patch_size, w=patch_w * self.patch_size, normalize=False).permute(0, 2, 1)
        if normalize:
            feature = F.normalize(feature, p=2, dim=-1)
        return feature
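
    # Worked example of the resizing above (added note; numbers are
    # illustrative): a 480x640 input with target_res=640 keeps
    # tgt_size=(480, 640), so patch_h, patch_w = 60, 80 and the image is
    # resized to (840, 1120) = (60*14, 80*14), yielding a 60x80 grid of
    # patch tokens from DINOv2.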

    def get_feature_wo_kp(self, rgbs, normalize=True):
        tgt_size = (int(rgbs.shape[-2] * self.target_res / rgbs.shape[-1]), self.target_res)
        if rgbs.shape[-2] > rgbs.shape[-1]:
            tgt_size = (self.target_res, int(rgbs.shape[-1] * self.target_res / rgbs.shape[-2]))

        patch_h, patch_w = tgt_size[0] // self.downsample_factor, tgt_size[1] // self.downsample_factor
        rgb_resized = functional.resize(rgbs, (patch_h * self.patch_size, patch_w * self.patch_size))

        result = self.dinov2.forward_features(self.input_transform(rgb_resized))
        feature = result['x_norm_patchtokens'].reshape(rgbs.shape[0], patch_h, patch_w, -1).permute(0, 3, 1, 2)
        feature = self.refine_conv(feature)
        # Return a dense per-pixel feature map at the original resolution
        # instead of sampling at keypoints.
        feature = functional.resize(feature, (rgbs.shape[-2], rgbs.shape[-1])).permute(0, 2, 3, 1)
        if normalize:
            feature = F.normalize(feature, p=2, dim=-1)
        return feature

    def training_step(self, batch, batch_idx):
        rgb_1, pts2d_1, pts3d_1 = batch['rgb_1'], batch['pts2d_1'], batch['pts3d_1']
        rgb_2, pts2d_2, pts3d_2 = batch['rgb_2'], batch['pts2d_2'], batch['pts3d_2']
        desc_1 = self.get_feature(rgb_1, pts2d_1, normalize=True)
        desc_2 = self.get_feature(rgb_2, pts2d_2, normalize=True)

        # Pairs are positives if their 3D keypoints are close; descriptor
        # similarity is the dot product of the normalized features.
        kp3d_dist = torch.cdist(pts3d_1, pts3d_2)          # B x S x T
        sim = torch.bmm(desc_1, desc_2.transpose(-1, -2))  # B x S x T

        pos_idxs = torch.nonzero(kp3d_dist < self.thresh3d_pos, as_tuple=False)
        pos_sim = sim[pos_idxs[:, 0], pos_idxs[:, 1], pos_idxs[:, 2]]
        rpos = sigmoid(pos_sim - 1., temp=0.01) + 1  # si = 1  # pos
        neg_mask = kp3d_dist[pos_idxs[:, 0], pos_idxs[:, 1]] > self.thresh3d_neg  # pos x T
        rall = rpos + torch.sum(sigmoid(sim[pos_idxs[:, 0], pos_idxs[:, 1]] - 1., temp=0.01) * neg_mask.float(), -1)  # pos
        ap1 = rpos / rall

        # Symmetric term: change the order, comparing negatives against the
        # positive similarity instead of 1.
        rpos = sigmoid(1. - pos_sim, temp=0.01) + 1  # si = 1  # pos
        neg_mask = kp3d_dist[pos_idxs[:, 0], pos_idxs[:, 1]] > self.thresh3d_neg  # pos x T
        rall = rpos + torch.sum(sigmoid(sim[pos_idxs[:, 0], pos_idxs[:, 1]] - pos_sim[:, None].repeat(1, sim.shape[-1]), temp=0.01) * neg_mask.float(), -1)  # pos
        ap2 = rpos / rall

        ap = (ap1 + ap2) / 2
        loss = torch.mean(1. - ap)
        self.log('loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Only the LoRA adapters and the refinement convolution are optimized;
        # the frozen backbone contributes no trainable parameters.
        return torch.optim.AdamW(
            [layer.weight for layer in self.w_As]
            + [layer.weight for layer in self.w_Bs]
            + list(self.refine_conv.parameters()),
            lr=1e-5, weight_decay=1e-4)
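
# --- Hypothetical usage (added sketch; the loader and its keys are
# assumptions inferred from training_step above, not a verbatim part of
# this module) ---
if __name__ == "__main__":
    model = FinetuneDINO(r=4, backbone_size="base")
    trainer = pl.Trainer(max_epochs=100, accelerator="auto", devices=1)
    # train_loader is assumed to yield dicts with keys 'rgb_1', 'pts2d_1',
    # 'pts3d_1', 'rgb_2', 'pts2d_2', 'pts3d_2' as consumed by training_step:
    # trainer.fit(model, train_loader)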