Spaces:
Sleeping
Sleeping
import copy | |
import torch | |
import numpy as np | |
import torch.nn as nn | |
from functools import partial | |
from torch.nn import functional as F | |
from croco.models.blocks import Block | |
from dust3r.model import AsymmetricCroCo3DStereo | |
class SpatialMemory():
    """Token-level key/value memory with a working-memory / long-term-memory split.

    Keys and values are stored as flat token sequences of shape ``[B, N, C]``; each
    new frame is concatenated along dim 1. The most recent ``work_mem_size`` frames
    form the working memory. Frames that leave it are counted (in tokens) towards the
    long-term memory, and once that count exceeds ``long_mem_size`` the memory is
    pruned to ``top_k`` tokens by attention received per storage step.

    Args:
        norm_q, norm_k, norm_v: callables applied to queries, keys and values in
            :meth:`memory_read` (``Spann3R`` passes ``nn.LayerNorm`` modules).
        mem_dropout: optional callable applied to the attention weights.
        long_mem_size: long-term token budget; also the number of tokens kept by
            :meth:`memory_prune`.
        work_mem_size: number of most recent frames kept in working memory.
        attn_thresh: attention weights below this are zeroed and rows renormalised in
            :meth:`memory_read`; ``0`` disables thresholding.
        sim_thresh: :meth:`add_mem_check` skips a frame whose mean cosine similarity
            with some working-memory frame exceeds this value.
        save_attn: if True, every attention map read is accumulated in ``attn_vis``.
    """

    def __init__(self, norm_q, norm_k, norm_v, mem_dropout=None,
                 long_mem_size=4000, work_mem_size=5,
                 attn_thresh=5e-4, sim_thresh=0.95,
                 save_attn=False):
        self.norm_q = norm_q
        self.norm_k = norm_k
        self.norm_v = norm_v
        self.mem_dropout = mem_dropout
        self.attn_thresh = attn_thresh
        self.long_mem_size = long_mem_size
        self.work_mem_size = work_mem_size
        # Number of tokens retained by memory_prune.
        self.top_k = long_mem_size
        self.save_attn = save_attn
        self.sim_thresh = sim_thresh
        self.init_mem()

    def init_mem(self):
        """Reset the memory to an empty state."""
        self.mem_k = None      # keys, [B, N, C]
        self.mem_v = None      # values, [B, N, C]
        self.mem_c = None      # optional per-token affinity scale (see memory_read)
        self.mem_count = None  # [B, N, 1]: add_mem calls since each token was stored
        self.mem_attn = None   # [B, N, 1]: attention accumulated by each token
        self.mem_pts = None
        self.mem_imgs = None
        self.lm = 0  # tokens currently counted as long-term memory
        self.wm = 0  # frames currently in working memory
        if self.save_attn:
            self.attn_vis = None

    def add_mem_k(self, feat):
        """Append ``feat`` to the key memory along dim 1 and return the key memory."""
        if self.mem_k is None:
            self.mem_k = feat
        else:
            self.mem_k = torch.cat((self.mem_k, feat), dim=1)
        return self.mem_k

    def add_mem_v(self, feat):
        """Append ``feat`` to the value memory along dim 1 and return the value memory."""
        if self.mem_v is None:
            self.mem_v = feat
        else:
            self.mem_v = torch.cat((self.mem_v, feat), dim=1)
        return self.mem_v

    def add_mem_c(self, feat):
        """Append ``feat`` to the affinity-scale memory along dim 1 and return it."""
        if self.mem_c is None:
            self.mem_c = feat
        else:
            self.mem_c = torch.cat((self.mem_c, feat), dim=1)
        return self.mem_c

    def add_mem_pts(self, pts_cur):
        """Append ``pts_cur`` (if given) to the point memory along dim 1."""
        if pts_cur is not None:
            if self.mem_pts is None:
                self.mem_pts = pts_cur
            else:
                self.mem_pts = torch.cat((self.mem_pts, pts_cur), dim=1)

    def add_mem_img(self, img_cur):
        """Append ``img_cur`` (if given) to the image memory along dim 1."""
        if img_cur is not None:
            if self.mem_imgs is None:
                self.mem_imgs = img_cur
            else:
                self.mem_imgs = torch.cat((self.mem_imgs, img_cur), dim=1)

    def add_mem(self, feat_k, feat_v, pts_cur=None, img_cur=None):
        """Unconditionally append one frame's keys and values (plus optional pts/imgs).

        Every token already in memory has its ``mem_count`` incremented. The new tokens
        start with count 0 and zero accumulated attention.
        """
        if self.mem_count is None:
            # Two separate tensors: mem_count is later updated in place.
            self.mem_count = torch.zeros_like(feat_k[:, :, :1])
            self.mem_attn = torch.zeros_like(feat_k[:, :, :1])
        else:
            self.mem_count += 1
            self.mem_count = torch.cat((self.mem_count, torch.zeros_like(feat_k[:, :, :1])), dim=1)
            self.mem_attn = torch.cat((self.mem_attn, torch.zeros_like(feat_k[:, :, :1])), dim=1)
        self.add_mem_k(feat_k)
        self.add_mem_v(feat_v)
        self.add_mem_pts(pts_cur)
        self.add_mem_img(img_cur)

    def check_sim(self, feat_k, thresh=0.7):
        """Return True if ``feat_k`` is nearly identical to a frame in working memory.

        Each working-memory frame is compared token by token with ``feat_k`` using
        cosine similarity. The frame counts as redundant if its mean similarity
        exceeds ``thresh``.

        Args:
            feat_k: keys of the candidate frame, ``[B, P, C]``; every stored frame is
                assumed to have the same number of tokens ``P``.
            thresh: similarity threshold; ``1.0`` disables the check.
        """
        if self.mem_k is None or thresh == 1.0:
            return False
        n_tok = feat_k.shape[1]
        wmem_size = self.wm * n_tok
        if wmem_size == 0:
            # mem_k[:, -0:] would select the whole memory, not an empty working memory.
            return False
        # work_mem: [B, T, P, C]. The working memory is the tail of mem_k, which
        # relies on memory_prune preserving chronological order.
        work_mem = self.mem_k[:, -wmem_size:].reshape(
            self.mem_k.shape[0], -1, n_tok, self.mem_k.shape[-1])
        feat_k_norm = F.normalize(feat_k, p=2, dim=-1)
        wm_norm = F.normalize(work_mem, p=2, dim=-1)
        corr = torch.einsum('bpc,btpc->btp', feat_k_norm, wm_norm)
        mean_corr = torch.mean(corr, dim=-1)
        if mean_corr.max() > thresh:
            print('Similarity detected:', mean_corr.max())
            return True
        return False

    def add_mem_check(self, feat_k, feat_v, pts_cur=None, img_cur=None):
        """Add a frame unless it is redundant, then rebalance working/long-term memory.

        The frame is added to working memory. When the working memory overflows, the
        oldest frame is either dropped (``long_mem_size == 0``) or counted towards the
        long-term memory. The memory is pruned once the long-term count exceeds
        ``long_mem_size`` tokens.
        """
        if self.check_sim(feat_k, thresh=self.sim_thresh):
            return
        self.add_mem(feat_k, feat_v, pts_cur, img_cur)
        n_tok = feat_k.shape[1]  # tokens per frame
        self.wm += 1
        if self.wm > self.work_mem_size:
            self.wm -= 1
            if self.long_mem_size == 0:
                # No long-term memory: drop the oldest frame's tokens.
                # NOTE(review): mem_pts / mem_imgs are not trimmed here.
                self.mem_k = self.mem_k[:, n_tok:]
                self.mem_v = self.mem_v[:, n_tok:]
                self.mem_count = self.mem_count[:, n_tok:]
                self.mem_attn = self.mem_attn[:, n_tok:]
                print('Memory pruned:', self.mem_k.shape)
            else:
                self.lm += n_tok
                if self.lm > self.long_mem_size:
                    self.memory_prune()
                    # Everything except the working memory is now long-term memory.
                    self.lm = self.mem_k.shape[1] - self.wm * n_tok

    def memory_read(self, feat, res=True):
        '''
        Attend from ``feat`` to the memory and return the read-out.

        Params:
            - feat: [bs, p, c] queries
            - res: if True, add ``feat`` to the read-out (residual connection)
        Memory (flattened to [bs, n, c] tokens before attention):
            - mem_k: keys
            - mem_v: values
            - mem_c: optional [bs, n, 1] multiplicative affinity scale
        Returns:
            [bs, p, c] read-out.
        Side effect: adds the attention each memory token received to ``mem_attn``.
        '''
        keys = self.norm_k(self.mem_k.reshape(self.mem_k.shape[0], -1, self.mem_k.shape[-1]))
        affinity = torch.einsum('bpc,bxc->bpx', self.norm_q(feat), keys)
        affinity /= torch.sqrt(torch.tensor(feat.shape[-1]).float())
        if self.mem_c is not None:
            affinity = affinity * self.mem_c.view(self.mem_c.shape[0], 1, -1)
        attn = torch.softmax(affinity, dim=-1)
        if self.save_attn:
            if self.attn_vis is None:
                self.attn_vis = attn.reshape(-1)
            else:
                self.attn_vis = torch.cat((self.attn_vis, attn.reshape(-1)), dim=0)
        if self.mem_dropout is not None:
            attn = self.mem_dropout(attn)
        if self.attn_thresh > 0:
            kept = attn.masked_fill(attn < self.attn_thresh, 0)
            row_sum = kept.sum(dim=-1, keepdim=True)
            # With near-uniform attention over a large memory, every weight can fall
            # below the threshold. Renormalising would then be 0/0 (NaN), so such
            # rows keep their unthresholded weights. The clamp keeps the unused
            # branch of torch.where finite.
            safe_sum = row_sum.clamp_min(torch.finfo(kept.dtype).tiny)
            attn = torch.where(row_sum > 0, kept / safe_sum, attn)
        values = self.norm_v(self.mem_v.reshape(self.mem_v.shape[0], -1, self.mem_v.shape[-1]))
        out = torch.einsum('bpx,bxc->bpc', attn, values)
        if res:
            out = out + feat
        # Bookkeeping only (used for pruning), so keep it out of the autograd graph.
        total_attn = torch.sum(attn, dim=-2)
        self.mem_attn += total_attn[..., None].detach()
        return out

    def memory_prune(self):
        """Keep the ``top_k`` tokens with the highest attention per storage step.

        Tokens stored fewer than ``work_mem_size + 5`` add_mem calls ago get weight
        1e8, so they are preferred. Kept tokens stay in their original (chronological)
        order, so the most recent frames remain at the tail of the memory, where
        ``check_sim`` and ``add_mem_check`` expect the working memory.
        """
        weights = self.mem_attn / self.mem_count
        weights[self.mem_count < self.work_mem_size + 5] = 1e8
        num_mem_b = self.mem_k.shape[1]
        k = min(self.top_k, num_mem_b)
        _, top_k_indices = torch.topk(weights, k, dim=1)
        # topk returns indices in score order; sort to restore chronological order.
        top_k_indices, _ = torch.sort(top_k_indices, dim=1)
        self.mem_k = torch.gather(self.mem_k, -2, top_k_indices.expand(-1, -1, self.mem_k.size(-1)))
        self.mem_v = torch.gather(self.mem_v, -2, top_k_indices.expand(-1, -1, self.mem_v.size(-1)))
        self.mem_attn = torch.gather(self.mem_attn, -2, top_k_indices)
        self.mem_count = torch.gather(self.mem_count, -2, top_k_indices)
        if self.mem_pts is not None:
            # NOTE(review): assumes per-token entries of shape [256, 3] for pts/imgs — confirm.
            idx_pts = top_k_indices.unsqueeze(-1).expand(-1, -1, 256, 3)
            self.mem_pts = torch.gather(self.mem_pts, 1, idx_pts)
            if self.mem_imgs is not None:
                self.mem_imgs = torch.gather(self.mem_imgs, 1, idx_pts)
        num_mem_a = self.mem_k.shape[1]
        print('Memory pruned:', num_mem_b, '->', num_mem_a)
import math | |
class Spann3R(nn.Module):
    """DUSt3R backbone extended with a spatial memory for sequential reconstruction.

    Frames are processed as pairs. The first view's encoder features are replaced by
    a read-out of :class:`SpatialMemory` (after the first pair), both views are decoded
    by the DUSt3R decoder, and pointmaps are regressed by its two heads. The first view
    is then written back into the memory: its key comes from ``attn_head_1`` and its
    value from the value encoder.
    """

    def __init__(self, dus3r_name="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
                 use_feat=False, mem_pos_enc=False, memory_dropout=0.15):
        # NOTE(review): ``dus3r_name`` is not used in this constructor.
        super(Spann3R, self).__init__()
        # config
        self.use_feat = use_feat
        self.mem_pos_enc = mem_pos_enc
        # DUSt3R
        self.dust3r = AsymmetricCroCo3DStereo(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768,
                                              enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R',
                                              img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -math.inf, math.inf),
                                              conf_mode=('exp', 1, math.inf), landscape_only=True)
        # Memory encoder
        self.set_memory_encoder(enc_embed_dim=768 if use_feat else 1024, memory_dropout=memory_dropout)
        self.set_attn_head()

    def set_memory_encoder(self, enc_depth=6, enc_embed_dim=1024, out_dim=1024, enc_num_heads=16,
                           mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                           memory_dropout=0.15):
        """Build the value encoder and the norms/dropout used by the spatial memory.

        The value encoder is ``enc_depth`` transformer blocks followed by a norm and a
        linear projection to ``out_dim``. When ``use_feat`` is False, values are computed
        from the predicted pointmap, which is embedded by ``pos_patch_embed``, a copy of
        the DUSt3R patch embedding initialised with the same weights.
        """
        self.value_encoder = nn.ModuleList([
            Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=norm_layer, rope=self.dust3r.rope if self.mem_pos_enc else None)
            for i in range(enc_depth)])
        self.value_norm = norm_layer(enc_embed_dim)
        self.value_out = nn.Linear(enc_embed_dim, out_dim)
        if not self.use_feat:
            self.pos_patch_embed = copy.deepcopy(self.dust3r.patch_embed)
            self.pos_patch_embed.load_state_dict(self.dust3r.patch_embed.state_dict())
        # Normalization layers
        self.norm_q = nn.LayerNorm(1024)
        self.norm_k = nn.LayerNorm(1024)
        self.norm_v = nn.LayerNorm(1024)
        self.mem_dropout = nn.Dropout(memory_dropout)

    def set_attn_head(self, enc_embed_dim=1024+768, out_dim=1024):
        """Build the two key heads: MLPs over concat(encoder feature, last decoder feature)."""
        self.attn_head_1 = nn.Sequential(
            nn.Linear(enc_embed_dim, enc_embed_dim),
            nn.GELU(),
            nn.Linear(enc_embed_dim, out_dim)
        )
        self.attn_head_2 = nn.Sequential(
            nn.Linear(enc_embed_dim, enc_embed_dim),
            nn.GELU(),
            nn.Linear(enc_embed_dim, out_dim)
        )

    def encode_image(self, view):
        """Encode one view; returns (features, positions, true_shape)."""
        img = view['img']
        B = img.shape[0]
        # Fall back to the tensor's own H, W when the view carries no 'true_shape'.
        im_shape = view.get('true_shape', torch.tensor(img.shape[-2:])[None].repeat(B, 1))
        out, pos, _ = self.dust3r._encode_image(img, im_shape)
        return out, pos, im_shape

    def encode_image_pairs(self, view1, view2):
        """Encode two views in one batched encoder call and split the result back."""
        img1 = view1['img']
        img2 = view2['img']
        B = img1.shape[0]
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))
        out, pos, _ = self.dust3r._encode_image(torch.cat((img1, img2), dim=0),
                                                torch.cat((shape1, shape2), dim=0))
        out, out2 = out.chunk(2, dim=0)
        pos, pos2 = pos.chunk(2, dim=0)
        return out, out2, pos, pos2, shape1, shape2

    def encode_frames(self, view1, view2, feat1, feat2, pos1, pos2, shape1, shape2):
        """Encode the next pair, reusing the previous second view as the new first view."""
        if feat1 is None:
            feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_image_pairs(view1, view2)
        else:
            feat1, pos1, shape1 = feat2, pos2, shape2
            feat2, pos2, shape2 = self.encode_image(view2)
        return feat1, feat2, pos1, pos2, shape1, shape2

    def encode_feat_key(self, feat1, feat2, num=1):
        """Memory key from concat(feat1, feat2) through ``attn_head_{num}``."""
        feat = torch.cat((feat1, feat2), dim=-1)
        feat_k = getattr(self, f'attn_head_{num}')(feat)
        return feat_k

    def encode_value(self, x, pos):
        """Run the value encoder blocks, final norm and output projection."""
        for block in self.value_encoder:
            x = block(x, pos)
        x = self.value_norm(x)
        x = self.value_out(x)
        return x

    def encode_cur_value(self, res1, dec1, pos1, shape1):
        """Memory value for the first view: from decoder features or its pointmap."""
        if self.use_feat:
            cur_v = self.encode_value(dec1[-1], pos1)
        else:
            # NOTE(review): assumes res1['pts3d'] is [B, H, W, 3]; permuted to channels-first.
            out, pos_v = self.pos_patch_embed(res1['pts3d'].permute(0, 3, 1, 2), true_shape=shape1)
            cur_v = self.encode_value(out, pos_v)
        return cur_v

    def decode(self, feat1, pos1, feat2, pos2):
        """Run the DUSt3R decoder on a pair of token sequences."""
        dec1, dec2 = self.dust3r._decoder(feat1, pos1, feat2, pos2)
        return dec1, dec2

    def downstream_head(self, dec, true_shape, num=1):
        """Regress head ``num`` in float32 (autocast disabled)."""
        with torch.cuda.amp.autocast(enabled=False):
            res = self.dust3r._downstream_head(num, [tok.float() for tok in dec], true_shape)
        return res

    def find_initial_pair(self, graph, n_frames):
        """Pick the graph pair with the highest summed mean confidence score.

        Each confidence ``c`` is mapped to ``(c - 1) / c`` before averaging. The
        diagonal is excluded, because a frame cannot start the reconstruction paired
        with itself. Returns a tuple of two Python ints.
        """
        view1, view2, pred1, pred2 = graph['view1'], graph['view2'], graph['pred1'], graph['pred2']
        n_pairs = len(view1['idx'])
        conf_matrix = torch.zeros(n_frames, n_frames)
        for i in range(n_pairs):
            idx1, idx2 = view1['idx'][i], view2['idx'][i]
            conf1 = pred1['conf'][i]
            conf2 = pred2['conf'][i]
            conf1_sig = (conf1-1)/conf1
            conf2_sig = (conf2-1)/conf2
            conf = conf1_sig.mean() + conf2_sig.mean()
            conf_matrix[idx1, idx2] = conf
        # Selecting (i, i) would make the caller remove index i from its to-do list twice.
        conf_matrix.fill_diagonal_(float('-inf'))
        pair_idx = tuple(int(i) for i in np.unravel_index(int(conf_matrix.argmax()), conf_matrix.shape))
        print(f'init pair:{pair_idx}, conf: {conf_matrix.max()}')
        return pair_idx

    def find_next_best_view(self, frames, idx_todo, feat_fuse, pos1, shape1):
        """Evaluate every remaining frame against ``feat_fuse``; return the best one.

        Returns ``(best_id, dec1, dec2, res1, res2, feat2, pos2, shape2, best_conf)``
        for the candidate with the highest summed mean ``(c - 1) / c`` confidence.
        Ties keep the earliest candidate.

        Raises:
            ValueError: if ``idx_todo`` is empty.
        """
        if len(idx_todo) == 0:
            raise ValueError('find_next_best_view needs at least one candidate frame')
        best = None
        best_conf = None
        for i in idx_todo:
            view = frames[i]
            feat2, pos2, shape2 = self.encode_image(view)
            dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)
            res1 = self.downstream_head(dec1, shape1, 1)
            res2 = self.downstream_head(dec2, shape2, 2)
            conf1 = res1['conf']
            conf2 = res2['conf']
            conf1_sig = (conf1-1)/conf1
            conf2_sig = (conf2-1)/conf2
            total_conf_mean = conf1_sig.mean() + conf2_sig.mean()
            # The first candidate is always accepted, so a result exists even when
            # every score is <= 0. The outputs are fresh per iteration and never
            # mutated here, so they are kept by reference (no deepcopy).
            if best_conf is None or total_conf_mean > best_conf:
                best_conf = total_conf_mean
                best = (i, dec1, dec2, res1, res2, feat2, pos2, shape2)
        best_id, best_dec1, best_dec2, best_res1, best_res2, best_feat2, best_pos2, best_shape2 = best
        return best_id, best_dec1, best_dec2, best_res1, best_res2, best_feat2, best_pos2, best_shape2, best_conf

    def offline_reconstruction(self, frames, graph):
        """Greedy reconstruction over an unordered frame set.

        The most confident pair in ``graph`` is processed first. Afterwards, the frame
        that scores best against the memory read-out is processed next, until all
        frames are used.

        Returns:
            (preds, preds_all, idx_used): per-frame predictions, per-pair predictions
            and the frame processing order.
        """
        n_frames = len(frames)
        idx_todo = list(range(n_frames))
        idx_used = []
        sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v, mem_dropout=self.mem_dropout)
        pair_idx = self.find_initial_pair(graph, n_frames)
        f1, f2 = frames[pair_idx[0]], frames[pair_idx[1]]
        idx_used.append(pair_idx[0])
        idx_used.append(pair_idx[1])
        # remove those idxs from idx_todo
        idx_todo.remove(pair_idx[0])
        idx_todo.remove(pair_idx[1])
        ##### Encode frames
        feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_image_pairs(f1, f2)
        feat_fuse = feat1
        dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)
        ##### Regress pointmaps
        with torch.cuda.amp.autocast(enabled=False):
            res1 = self.downstream_head(dec1, shape1, 1)
            res2 = self.downstream_head(dec2, shape2, 2)
        ##### Encode feat key
        feat_k2 = None
        preds = None
        while True:
            if feat_k2 is not None:
                # The previous second view becomes the new first view.
                feat1 = feat2
                pos1, shape1 = pos2, shape2
                feat_fuse = sp_mem.memory_read(feat_k2, res=True)
                id_n, dec1, dec2, res1, res2, feat2, pos2, shape2, best_conf = self.find_next_best_view(frames, idx_todo, feat_fuse, pos2, shape2)
                idx_todo.remove(id_n)
                idx_used.append(id_n)
                print(f'next best view: {id_n}, conf: {best_conf}')
            # encode feat
            feat_k1 = self.encode_feat_key(feat1, dec1[-1], 1)
            feat_k2 = self.encode_feat_key(feat2, dec2[-1], 2)
            ##### Memory update
            cur_v = self.encode_cur_value(res1, dec1, pos1, shape1)
            sp_mem.add_mem_check(feat_k1, cur_v+feat_k1)
            res2['pts3d_in_other_view'] = res2.pop('pts3d')
            if preds is None:
                preds = [res1]
                preds_all = [(res1, res2)]
            else:
                res1['pts3d_in_other_view'] = res1.pop('pts3d')
                preds.append(res1)
                preds_all.append((res1, res2))
            if len(idx_todo) == 0:
                break
        preds.append(res2)
        return preds, preds_all, idx_used

    def forward(self, frames, return_memory=False):
        """Sequentially reconstruct ``frames`` as consecutive pairs (i, i + 1).

        Returns:
            (preds, preds_all) and additionally the ``SpatialMemory`` when
            ``return_memory`` is True. ``preds`` has one entry per frame; ``preds_all``
            has one ``(res1, res2)`` tuple per pair.

        Raises:
            ValueError: if fewer than two frames are given.
        """
        if len(frames) < 2:
            raise ValueError(f'Spann3R.forward needs at least 2 frames, got {len(frames)}')
        if self.training:
            sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v, mem_dropout=self.mem_dropout, attn_thresh=0)
        else:
            sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v)
        feat1, feat2, pos1, pos2, shape1, shape2 = None, None, None, None, None, None
        feat_k1, feat_k2 = None, None
        preds = None
        preds_all = []
        for i in range(len(frames) - 1):
            view1 = frames[i]
            view2 = frames[i + 1]
            ##### Encode frames
            # feat1: [bs, p=196, c=1024]
            feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_frames(view1, view2, feat1, feat2, pos1, pos2, shape1, shape2)
            ##### Memory readout
            if feat_k2 is not None:
                feat_fuse = sp_mem.memory_read(feat_k2, res=True)
            else:
                feat_fuse = feat1
            ##### Decode features
            # dec1[-1]: [bs, p, c=768]
            dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)
            ##### Encode feat key
            feat_k1 = self.encode_feat_key(feat1, dec1[-1], 1)
            feat_k2 = self.encode_feat_key(feat2, dec2[-1], 2)
            ##### Regress pointmaps
            with torch.cuda.amp.autocast(enabled=False):
                res1 = self.downstream_head(dec1, shape1, 1)
                res2 = self.downstream_head(dec2, shape2, 2)
            ##### Memory update
            cur_v = self.encode_cur_value(res1, dec1, pos1, shape1)
            if self.training:
                sp_mem.add_mem(feat_k1, cur_v+feat_k1)
            else:
                sp_mem.add_mem_check(feat_k1, cur_v+feat_k1)
            res2['pts3d_in_other_view'] = res2.pop('pts3d')
            if preds is None:
                preds = [res1]
                preds_all = [(res1, res2)]
            else:
                res1['pts3d_in_other_view'] = res1.pop('pts3d')
                preds.append(res1)
                preds_all.append((res1, res2))
        preds.append(res2)
        if return_memory:
            return preds, preds_all, sp_mem
        return preds, preds_all