# StableRecon/spann3r/model.py
import copy
import math

import numpy as np
import torch
import torch.nn as nn
from functools import partial
from torch.nn import functional as F

from croco.models.blocks import Block
from dust3r.model import AsymmetricCroCo3DStereo


class SpatialMemory:
    """Feature memory bank with a working memory of recent frames and a
    long-term memory pruned by accumulated attention statistics."""

    def __init__(self, norm_q, norm_k, norm_v, mem_dropout=None,
                 long_mem_size=4000, work_mem_size=5,
                 attn_thresh=5e-4, sim_thresh=0.95,
                 save_attn=False):
        self.norm_q = norm_q
        self.norm_k = norm_k
        self.norm_v = norm_v
        self.mem_dropout = mem_dropout
        self.attn_thresh = attn_thresh
        self.long_mem_size = long_mem_size
        self.work_mem_size = work_mem_size
        self.top_k = long_mem_size
        self.save_attn = save_attn
        self.sim_thresh = sim_thresh

        self.init_mem()

    def init_mem(self):
        self.mem_k = None
        self.mem_v = None
        self.mem_c = None
        self.mem_count = None
        self.mem_attn = None
        self.mem_pts = None
        self.mem_imgs = None

        self.lm = 0  # long-term memory size, in tokens
        self.wm = 0  # working memory size, in frames

        if self.save_attn:
            self.attn_vis = None

    def add_mem_k(self, feat):
        if self.mem_k is None:
            self.mem_k = feat
        else:
            self.mem_k = torch.cat((self.mem_k, feat), dim=1)
        return self.mem_k

    def add_mem_v(self, feat):
        if self.mem_v is None:
            self.mem_v = feat
        else:
            self.mem_v = torch.cat((self.mem_v, feat), dim=1)
        return self.mem_v

    def add_mem_c(self, feat):
        if self.mem_c is None:
            self.mem_c = feat
        else:
            self.mem_c = torch.cat((self.mem_c, feat), dim=1)
        return self.mem_c

    def add_mem_pts(self, pts_cur):
        if pts_cur is not None:
            if self.mem_pts is None:
                self.mem_pts = pts_cur
            else:
                self.mem_pts = torch.cat((self.mem_pts, pts_cur), dim=1)

    def add_mem_img(self, img_cur):
        if img_cur is not None:
            if self.mem_imgs is None:
                self.mem_imgs = img_cur
            else:
                self.mem_imgs = torch.cat((self.mem_imgs, img_cur), dim=1)

    def add_mem(self, feat_k, feat_v, pts_cur=None, img_cur=None):
        # mem_count tracks how many frames each token has been stored for;
        # mem_attn accumulates the attention it receives (see memory_read).
        # Both drive the pruning in memory_prune.
        if self.mem_count is None:
            self.mem_count = torch.zeros_like(feat_k[:, :, :1])
            self.mem_attn = torch.zeros_like(feat_k[:, :, :1])
        else:
            self.mem_count += 1
            self.mem_count = torch.cat((self.mem_count, torch.zeros_like(feat_k[:, :, :1])), dim=1)
            self.mem_attn = torch.cat((self.mem_attn, torch.zeros_like(feat_k[:, :, :1])), dim=1)

        self.add_mem_k(feat_k)
        self.add_mem_v(feat_v)
        self.add_mem_pts(pts_cur)
        self.add_mem_img(img_cur)

    def check_sim(self, feat_k, thresh=0.7):
        # Correlate the incoming keys with the working memory; if any stored
        # frame is nearly identical (a redundant view), report a hit so the
        # caller can skip insertion.
        if self.mem_k is None or thresh == 1.0:
            return False

        wmem_size = self.wm * 196  # 196 patch tokens per frame (hard-coded)

        # wm: [bs, t, 196, c]
        wm = self.mem_k[:, -wmem_size:].reshape(self.mem_k.shape[0], -1, 196, self.mem_k.shape[-1])

        feat_k_norm = F.normalize(feat_k, p=2, dim=-1)
        wm_norm = F.normalize(wm, p=2, dim=-1)

        corr = torch.einsum('bpc,btpc->btp', feat_k_norm, wm_norm)
        mean_corr = torch.mean(corr, dim=-1)

        if mean_corr.max() > thresh:
            print('Similarity detected:', mean_corr.max())
            return True
        return False
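
    # A worked sketch of the gate above (hypothetical shapes): with one stored
    # frame, feat_k [1, 196, C] and wm [1, 1, 196, C] are L2-normalised, so
    # 'bpc,btpc->btp' yields per-patch cosine similarities of shape
    # [1, 1, 196]; the per-frame mean must exceed thresh (sim_thresh=0.95 via
    # add_mem_check) for the frame to be treated as a duplicate.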

    def add_mem_check(self, feat_k, feat_v, pts_cur=None, img_cur=None):
        if self.check_sim(feat_k, thresh=self.sim_thresh):
            return

        self.add_mem(feat_k, feat_v, pts_cur, img_cur)
        self.wm += 1

        if self.wm > self.work_mem_size:
            self.wm -= 1

            if self.long_mem_size == 0:
                # No long-term memory: drop the oldest frame's tokens outright.
                self.mem_k = self.mem_k[:, 196:]
                self.mem_v = self.mem_v[:, 196:]
                self.mem_count = self.mem_count[:, 196:]
                self.mem_attn = self.mem_attn[:, 196:]
                print('Memory pruned:', self.mem_k.shape)
            else:
                # The oldest working-memory frame graduates to long-term memory.
                self.lm += 196  # TODO: Change this to the actual size of the memory bank

                if self.lm > self.long_mem_size:
                    self.memory_prune()
                    self.lm = self.top_k - self.wm * 196
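
    # Budget arithmetic with the defaults above (a worked example, no extra
    # logic): the working memory holds work_mem_size=5 frames of 196 tokens.
    # Adding a 6th frame caps wm at 5 and credits the displaced 196 tokens to
    # the long-term pool (lm += 196). Once lm exceeds long_mem_size=4000,
    # memory_prune() keeps the top_k=4000 most-attended tokens overall and lm
    # resets to top_k - wm * 196 = 4000 - 980 = 3020.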

    def memory_read(self, feat, res=True):
        '''
        Params:
        - feat: [bs, p, c] query features
        - mem_k: [bs, t*p, c] keys (frames concatenated along the token dim)
        - mem_v: [bs, t*p, c] values
        - mem_c: [bs, t*p, 1] optional per-token confidence
        '''
        affinity = torch.einsum('bpc,bxc->bpx', self.norm_q(feat),
                                self.norm_k(self.mem_k.reshape(self.mem_k.shape[0], -1, self.mem_k.shape[-1])))
        affinity /= torch.sqrt(torch.tensor(feat.shape[-1]).float())

        if self.mem_c is not None:
            affinity = affinity * self.mem_c.view(self.mem_c.shape[0], 1, -1)

        attn = torch.softmax(affinity, dim=-1)

        if self.save_attn:
            if self.attn_vis is None:
                self.attn_vis = attn.reshape(-1)
            else:
                self.attn_vis = torch.cat((self.attn_vis, attn.reshape(-1)), dim=0)

        if self.mem_dropout is not None:
            attn = self.mem_dropout(attn)

        if self.attn_thresh > 0:
            # Sparsify: drop negligible weights, then renormalise the rest.
            attn[attn < self.attn_thresh] = 0
            attn = attn / attn.sum(dim=-1, keepdim=True)

        out = torch.einsum('bpx,bxc->bpc', attn,
                           self.norm_v(self.mem_v.reshape(self.mem_v.shape[0], -1, self.mem_v.shape[-1])))

        if res:
            out = out + feat

        # Accumulate how much attention each memory token received; this is
        # the signal memory_prune uses to rank tokens.
        total_attn = torch.sum(attn, dim=-2)
        self.mem_attn += total_attn[..., None]

        return out
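
    # The readout above is scaled dot-product attention,
    #   out = softmax(norm_q(Q) @ norm_k(K).T / sqrt(c)) @ norm_v(V)  (+ Q),
    # with an optional confidence reweighting before the softmax and the
    # attn_thresh sparsification applied to the attention weights.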

    def memory_prune(self):
        # Rank tokens by mean attention received per frame of lifetime; tokens
        # younger than work_mem_size + 5 frames get a huge weight so they are
        # always retained.
        weights = self.mem_attn / self.mem_count
        weights[self.mem_count < self.work_mem_size + 5] = 1e8
        num_mem_b = self.mem_k.shape[1]

        top_k_values, top_k_indices = torch.topk(weights, self.top_k, dim=1)

        top_k_indices_expanded = top_k_indices.expand(-1, -1, self.mem_k.size(-1))

        self.mem_k = torch.gather(self.mem_k, -2, top_k_indices_expanded)
        self.mem_v = torch.gather(self.mem_v, -2, top_k_indices_expanded)
        self.mem_attn = torch.gather(self.mem_attn, -2, top_k_indices)
        self.mem_count = torch.gather(self.mem_count, -2, top_k_indices)

        if self.mem_pts is not None:
            top_k_indices_expanded = top_k_indices.unsqueeze(-1).expand(-1, -1, 256, 3)
            self.mem_pts = torch.gather(self.mem_pts, 1, top_k_indices_expanded)
            self.mem_imgs = torch.gather(self.mem_imgs, 1, top_k_indices_expanded)

        num_mem_a = self.mem_k.shape[1]
        print('Memory pruned:', num_mem_b, '->', num_mem_a)
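
# Usage sketch for SpatialMemory on its own (hypothetical shapes; one frame
# contributes 196 key/value tokens of dim C, and the norm_* modules are the
# LayerNorms that Spann3R builds in set_memory_encoder):
#
#   mem = SpatialMemory(nn.LayerNorm(C), nn.LayerNorm(C), nn.LayerNorm(C))
#   mem.add_mem_check(feat_k, feat_v)   # insert the current frame
#   fused = mem.memory_read(query_k)    # [bs, 196, C] readout + residual
#
# A runnable version of this sketch is at the bottom of the file.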


class Spann3R(nn.Module):
    def __init__(self, dus3r_name="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
                 use_feat=False, mem_pos_enc=False, memory_dropout=0.15):
        super().__init__()

        # config
        self.use_feat = use_feat
        self.mem_pos_enc = mem_pos_enc

        # DUSt3R backbone (note: dus3r_name is kept for reference but not
        # consumed in this file)
        self.dust3r = AsymmetricCroCo3DStereo(
            enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768,
            enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100',
            patch_embed_cls='PatchEmbedDust3R', img_size=(512, 512),
            head_type='dpt', output_mode='pts3d',
            depth_mode=('exp', -math.inf, math.inf),
            conf_mode=('exp', 1, math.inf), landscape_only=True)

        # Memory encoder
        self.set_memory_encoder(enc_embed_dim=768 if use_feat else 1024,
                                memory_dropout=memory_dropout)
        self.set_attn_head()

    def set_memory_encoder(self, enc_depth=6, enc_embed_dim=1024, out_dim=1024, enc_num_heads=16,
                           mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                           memory_dropout=0.15):
        self.value_encoder = nn.ModuleList([
            Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=norm_layer, rope=self.dust3r.rope if self.mem_pos_enc else None)
            for i in range(enc_depth)])
        self.value_norm = norm_layer(enc_embed_dim)
        self.value_out = nn.Linear(enc_embed_dim, out_dim)

        if not self.use_feat:
            # Patch embedding for pointmaps; deepcopy already clones the
            # weights, so no separate load_state_dict is required.
            self.pos_patch_embed = copy.deepcopy(self.dust3r.patch_embed)

        # Normalization layers
        self.norm_q = nn.LayerNorm(1024)
        self.norm_k = nn.LayerNorm(1024)
        self.norm_v = nn.LayerNorm(1024)

        self.mem_dropout = nn.Dropout(memory_dropout)

    def set_attn_head(self, enc_embed_dim=1024 + 768, out_dim=1024):
        self.attn_head_1 = nn.Sequential(
            nn.Linear(enc_embed_dim, enc_embed_dim),
            nn.GELU(),
            nn.Linear(enc_embed_dim, out_dim)
        )
        self.attn_head_2 = nn.Sequential(
            nn.Linear(enc_embed_dim, enc_embed_dim),
            nn.GELU(),
            nn.Linear(enc_embed_dim, out_dim)
        )

    def encode_image(self, view):
        img = view['img']
        B = img.shape[0]
        im_shape = view.get('true_shape', torch.tensor(img.shape[-2:])[None].repeat(B, 1))

        out, pos, _ = self.dust3r._encode_image(img, im_shape)
        return out, pos, im_shape

    def encode_image_pairs(self, view1, view2):
        img1 = view1['img']
        img2 = view2['img']
        B = img1.shape[0]
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))

        # Run both views through the encoder in a single batched pass.
        out, pos, _ = self.dust3r._encode_image(torch.cat((img1, img2), dim=0),
                                                torch.cat((shape1, shape2), dim=0))
        out, out2 = out.chunk(2, dim=0)
        pos, pos2 = pos.chunk(2, dim=0)

        return out, out2, pos, pos2, shape1, shape2

    def encode_frames(self, view1, view2, feat1, feat2, pos1, pos2, shape1, shape2):
        if feat1 is None:
            # First step: encode both views together.
            feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_image_pairs(view1, view2)
        else:
            # Later steps: reuse the previous frame's encoding, only encode the new view.
            feat1, pos1, shape1 = feat2, pos2, shape2
            feat2, pos2, shape2 = self.encode_image(view2)

        return feat1, feat2, pos1, pos2, shape1, shape2

    def encode_feat_key(self, feat1, feat2, num=1):
        # Build the memory key from concatenated encoder and decoder features.
        feat = torch.cat((feat1, feat2), dim=-1)
        feat_k = getattr(self, f'attn_head_{num}')(feat)
        return feat_k
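
    # Shapes as used in forward() below (batch size hypothetical):
    #   feat1:    [bs, 196, 1024] encoder tokens
    #   dec1[-1]: [bs, 196, 768] last decoder tokens
    #   -> cat:   [bs, 196, 1792] -> attn_head_{num} -> [bs, 196, 1024]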

    def encode_value(self, x, pos):
        for block in self.value_encoder:
            x = block(x, pos)
        x = self.value_norm(x)
        x = self.value_out(x)
        return x

    def encode_cur_value(self, res1, dec1, pos1, shape1):
        if self.use_feat:
            # Values from the last decoder features.
            cur_v = self.encode_value(dec1[-1], pos1)
        else:
            # Values from the predicted pointmap, re-embedded into patches.
            out, pos_v = self.pos_patch_embed(res1['pts3d'].permute(0, 3, 1, 2), true_shape=shape1)
            cur_v = self.encode_value(out, pos_v)
        return cur_v

    def decode(self, feat1, pos1, feat2, pos2):
        dec1, dec2 = self.dust3r._decoder(feat1, pos1, feat2, pos2)
        return dec1, dec2

    def downstream_head(self, dec, true_shape, num=1):
        with torch.cuda.amp.autocast(enabled=False):
            res = self.dust3r._downstream_head(num, [tok.float() for tok in dec], true_shape)
        return res

    def find_initial_pair(self, graph, n_frames):
        view1, view2, pred1, pred2 = graph['view1'], graph['view2'], graph['pred1'], graph['pred2']
        n_pairs = len(view1['idx'])

        conf_matrix = torch.zeros(n_frames, n_frames)

        for i in range(n_pairs):
            idx1, idx2 = view1['idx'][i], view2['idx'][i]
            conf1 = pred1['conf'][i]
            conf2 = pred2['conf'][i]

            # Map conf in [1, inf) to [0, 1) so the means are comparable.
            conf1_sig = (conf1 - 1) / conf1
            conf2_sig = (conf2 - 1) / conf2

            conf = conf1_sig.mean() + conf2_sig.mean()
            conf_matrix[idx1, idx2] = conf

        pair_idx = np.unravel_index(conf_matrix.argmax(), conf_matrix.shape)
        print(f'init pair: {pair_idx}, conf: {conf_matrix.max()}')

        return pair_idx
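
    # The squashing above saturates: conf = 1 -> 0.0, conf = 2 -> 0.5,
    # conf = 10 -> 0.9, approaching 1 as conf grows, so a handful of extremely
    # confident pixels cannot dominate the per-pair mean.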

    def find_next_best_view(self, frames, idx_todo, feat_fuse, pos1, shape1):
        best_conf = 0.0

        for i in idx_todo:
            view = frames[i]
            feat2, pos2, shape2 = self.encode_image(view)
            dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)
            res1 = self.downstream_head(dec1, shape1, 1)
            res2 = self.downstream_head(dec2, shape2, 2)

            conf1 = res1['conf']
            conf2 = res2['conf']
            conf1_sig = (conf1 - 1) / conf1
            conf2_sig = (conf2 - 1) / conf2

            total_conf_mean = conf1_sig.mean() + conf2_sig.mean()

            if total_conf_mean > best_conf:
                best_conf = total_conf_mean
                best_id = i
                best_dec1 = copy.deepcopy(dec1)
                best_dec2 = copy.deepcopy(dec2)
                best_res1 = copy.deepcopy(res1)
                best_res2 = copy.deepcopy(res2)
                best_feat2 = feat2
                best_pos2 = pos2
                best_shape2 = shape2

        return best_id, best_dec1, best_dec2, best_res1, best_res2, best_feat2, best_pos2, best_shape2, best_conf
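
    # Cost note: each call decodes every remaining frame against the current
    # reference, so the greedy ordering in offline_reconstruction performs
    # O(n^2) decoder passes over n frames.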

    def offline_reconstruction(self, frames, graph):
        n_frames = len(frames)
        idx_todo = list(range(n_frames))
        idx_used = []

        sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v, mem_dropout=self.mem_dropout)

        ##### Initial pair
        pair_idx = self.find_initial_pair(graph, n_frames)
        f1, f2 = frames[pair_idx[0]], frames[pair_idx[1]]

        idx_used.append(pair_idx[0])
        idx_used.append(pair_idx[1])
        # remove those idxs from idx_todo
        idx_todo.remove(pair_idx[0])
        idx_todo.remove(pair_idx[1])

        ##### Encode frames
        feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_image_pairs(f1, f2)
        feat_fuse = feat1

        dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)

        ##### Regress pointmaps
        with torch.cuda.amp.autocast(enabled=False):
            res1 = self.downstream_head(dec1, shape1, 1)
            res2 = self.downstream_head(dec2, shape2, 2)

        feat_k2 = None
        preds = None

        while True:
            if feat_k2 is not None:
                # From the second iteration on, the previous target becomes
                # the reference, fused with the memory readout, and the next
                # view is chosen greedily by confidence.
                feat1 = feat2
                pos1, shape1 = pos2, shape2
                feat_fuse = sp_mem.memory_read(feat_k2, res=True)

                id_n, dec1, dec2, res1, res2, feat2, pos2, shape2, best_conf = \
                    self.find_next_best_view(frames, idx_todo, feat_fuse, pos1, shape1)

                idx_todo.remove(id_n)
                idx_used.append(id_n)
                print(f'next best view: {id_n}, conf: {best_conf}')

            ##### Encode feat keys
            feat_k1 = self.encode_feat_key(feat1, dec1[-1], 1)
            feat_k2 = self.encode_feat_key(feat2, dec2[-1], 2)

            ##### Memory update
            cur_v = self.encode_cur_value(res1, dec1, pos1, shape1)
            sp_mem.add_mem_check(feat_k1, cur_v + feat_k1)

            res2['pts3d_in_other_view'] = res2.pop('pts3d')

            if preds is None:
                preds = [res1]
                preds_all = [(res1, res2)]
            else:
                res1['pts3d_in_other_view'] = res1.pop('pts3d')
                preds.append(res1)
                preds_all.append((res1, res2))

            if len(idx_todo) == 0:
                break

        preds.append(res2)

        return preds, preds_all, idx_used
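
    # Offline mode runs the same read / decode / write cycle per step as
    # forward() below; the main difference is that views are visited in a
    # confidence-driven order instead of input order.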

    def forward(self, frames, return_memory=False):
        if self.training:
            sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v,
                                   mem_dropout=self.mem_dropout, attn_thresh=0)
        else:
            sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v)

        feat1, feat2, pos1, pos2, shape1, shape2 = None, None, None, None, None, None
        feat_k1, feat_k2 = None, None
        preds = None
        preds_all = []

        for i in range(len(frames) - 1):
            view1 = frames[i]
            view2 = frames[i + 1]

            ##### Encode frames
            # feat1: [bs, p=196, c=1024]
            feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_frames(
                view1, view2, feat1, feat2, pos1, pos2, shape1, shape2)

            ##### Memory readout
            if feat_k2 is not None:
                feat_fuse = sp_mem.memory_read(feat_k2, res=True)
                # feat_fuse = feat_fuse + feat1
            else:
                feat_fuse = feat1

            ##### Decode features
            # dec1[-1]: [bs, p, c=768]
            dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)

            ##### Encode feat keys
            feat_k1 = self.encode_feat_key(feat1, dec1[-1], 1)
            feat_k2 = self.encode_feat_key(feat2, dec2[-1], 2)

            ##### Regress pointmaps
            with torch.cuda.amp.autocast(enabled=False):
                res1 = self.downstream_head(dec1, shape1, 1)
                res2 = self.downstream_head(dec2, shape2, 2)

            ##### Memory update
            cur_v = self.encode_cur_value(res1, dec1, pos1, shape1)

            if self.training:
                sp_mem.add_mem(feat_k1, cur_v + feat_k1)
            else:
                sp_mem.add_mem_check(feat_k1, cur_v + feat_k1)

            res2['pts3d_in_other_view'] = res2.pop('pts3d')

            if preds is None:
                preds = [res1]
                preds_all = [(res1, res2)]
            else:
                res1['pts3d_in_other_view'] = res1.pop('pts3d')
                preds.append(res1)
                preds_all.append((res1, res2))

        preds.append(res2)

        if return_memory:
            return preds, preds_all, sp_mem

        return preds, preds_all
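

if __name__ == '__main__':
    # Minimal smoke test for SpatialMemory in isolation: a sketch with
    # illustrative dimensions, far smaller than the 196 x 1024 tokens the
    # full model produces. sim_thresh=1.0 disables the similarity gate,
    # which would be meaningless for random features.
    torch.manual_seed(0)
    dim = 32
    mem = SpatialMemory(nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim),
                        sim_thresh=1.0)

    for _ in range(3):
        feat_k = torch.randn(1, 196, dim)  # one frame of patch-token keys
        feat_v = torch.randn(1, 196, dim)  # matching values
        mem.add_mem_check(feat_k, feat_v)

    query = torch.randn(1, 196, dim)
    fused = mem.memory_read(query)         # attention readout + residual
    print(fused.shape)                     # torch.Size([1, 196, 32])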