geyongtao commited on
Commit
57214d4
·
verified ·
1 Parent(s): b80ebcd

Create catmlp_dpt_head.py

Browse files
Files changed (1) hide show
  1. catmlp_dpt_head.py +94 -0
catmlp_dpt_head.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # MASt3R heads
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from mini_dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf # noqa
11
+ from mini_dust3r.heads.dpt_head import PixelwiseTaskWithDPT # noqa
12
+ from mini_dust3r.croco.blocks import Mlp # noqa
13
+
14
+ def reg_desc(desc, mode):
15
+ if 'norm' in mode:
16
+ desc = desc / desc.norm(dim=-1, keepdim=True)
17
+ else:
18
+ raise ValueError(f"Unknown desc mode {mode}")
19
+ return desc
20
+
21
+
22
+ def postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None):
23
+ if desc_conf_mode is None:
24
+ desc_conf_mode = conf_mode
25
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,D
26
+ res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
27
+ if conf_mode is not None:
28
+ res['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
29
+ if desc_dim is not None:
30
+ start = 3 + int(conf_mode is not None)
31
+ res['desc'] = reg_desc(fmap[..., start:start + desc_dim], mode=desc_mode)
32
+ if two_confs:
33
+ res['desc_conf'] = reg_dense_conf(fmap[..., start + desc_dim], mode=desc_conf_mode)
34
+ else:
35
+ res['desc_conf'] = res['conf'].clone()
36
+ return res
37
+
38
+
39
+ class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
40
+ """ Mixture between MLP and DPT head that outputs 3d points and local features (with MLP).
41
+ The input for both heads is a concatenation of Encoder and Decoder outputs
42
+ """
43
+
44
+ def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None,
45
+ num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None, head_type="regression", **kwargs):
46
+ super().__init__(num_channels=num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx,
47
+ dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode, head_type=head_type)
48
+ self.local_feat_dim = local_feat_dim
49
+
50
+ patch_size = net.patch_embed.patch_size
51
+ if isinstance(patch_size, tuple):
52
+ assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance(
53
+ patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
54
+ assert patch_size[0] == patch_size[1], "Error, non square patches not managed"
55
+ patch_size = patch_size[0]
56
+ self.patch_size = patch_size
57
+
58
+ self.desc_mode = net.desc_mode
59
+ self.has_conf = has_conf
60
+ self.two_confs = net.two_confs # independent confs for 3D regr and descs
61
+ self.desc_conf_mode = net.desc_conf_mode
62
+ idim = net.enc_embed_dim + net.dec_embed_dim
63
+
64
+ self.head_local_features = Mlp(in_features=idim,
65
+ hidden_features=int(hidden_dim_factor * idim),
66
+ out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2)
67
+
68
+ def forward(self, decout, img_shape):
69
+ # pass through the heads
70
+ pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1]))
71
+
72
+ # recover encoder and decoder outputs
73
+ enc_output, dec_output = decout[0], decout[-1]
74
+ cat_output = torch.cat([enc_output, dec_output], dim=-1) # concatenate
75
+ H, W = img_shape
76
+ B, S, D = cat_output.shape
77
+
78
+ # extract local_features
79
+ local_features = self.head_local_features(cat_output) # B,S,D
80
+ local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
81
+ local_features = F.pixel_shuffle(local_features, self.patch_size) # B,d,H,W
82
+
83
+ # post process 3D pts, descriptors and confidences
84
+ out = torch.cat([pts3d, local_features], dim=1)
85
+ if self.postprocess:
86
+ out = self.postprocess(out,
87
+ depth_mode=self.depth_mode,
88
+ conf_mode=self.conf_mode,
89
+ desc_dim=self.local_feat_dim,
90
+ desc_mode=self.desc_mode,
91
+ two_confs=self.two_confs,
92
+ desc_conf_mode=self.desc_conf_mode)
93
+ return out
94
+