|
|
|
|
|
|
|
|
|
|
|
|
|
from copy import deepcopy |
|
import torch |
|
import os |
|
from packaging import version |
|
import huggingface_hub |
|
|
|
from .utils.misc import ( |
|
fill_default_args, |
|
freeze_all_params, |
|
is_symmetrized, |
|
interleave, |
|
transpose_to_landscape, |
|
) |
|
from .heads import head_factory |
|
from mini_dust3r.patch_embed import get_patch_embed |
|
|
|
from mini_dust3r.croco.croco import CroCoNet |
|
|
|
# Module-level name used by the default ``depth_mode``/``conf_mode`` tuples of
# the model class. Because ``load_model`` ``eval``s constructor strings in this
# module's namespace, those strings can also refer to ``inf`` (the repr of
# ``float("inf")`` is ``"inf"``).
inf = float("inf")

hf_version_number = huggingface_hub.__version__

# Minimum huggingface_hub version required by this module. This is an explicit
# raise rather than an ``assert`` so the check is not stripped under ``python -O``.
if version.parse(hf_version_number) < version.parse("0.22.0"):
    raise ImportError(
        "Outdated huggingface_hub version, please reinstall requirements.txt"
    )
|
|
|
|
|
def load_model(model_path, device, verbose=True):
    """Rebuild a network from a checkpoint file and load its weights.

    The checkpoint's ``ckpt["args"].model`` string is a constructor
    expression. It is rewritten so that ``ManyAR_PatchEmbed`` becomes
    ``PatchEmbedDust3R`` and ``landscape_only`` is forced to ``False``, then
    ``eval``-ed to instantiate the network. The weights in ``ckpt["model"]``
    are loaded with ``strict=False``.

    Args:
        model_path: path to a checkpoint file readable by ``torch.load``.
        device: device the returned network is moved to.
        verbose: if True, print progress, the final constructor string and
            the result of ``load_state_dict``.

    Returns:
        The instantiated network, moved to ``device``.

    Raises:
        AssertionError: if the rewritten constructor string does not contain
            ``landscape_only=False``.
    """
    if verbose:
        print("... loading model from", model_path)
    # SECURITY NOTE(review): torch.load unpickles the file, and the ``eval``
    # below executes a string read from it. Only load trusted checkpoints.
    ckpt = torch.load(model_path, map_location="cpu")
    # ckpt["args"] presumably is an argparse-style namespace saved at training
    # time; its ``model`` attribute is the constructor expression.
    args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if "landscape_only" not in args:
        # Drop the last character (presumably the closing ")") and re-close
        # the call with the extra keyword argument.
        args = args[:-1] + ", landscape_only=False)"
    else:
        # Spaces are removed first so "landscape_only = True" also matches.
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    # eval resolves names (presumably the model class name, ``inf``, ...) in
    # this module's namespace.
    net = eval(args)
    # Non-strict load; ``s`` is whatever ``net.load_state_dict`` returns
    # (missing/unexpected keys for a torch nn.Module).
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)
|
|
|
|
|
class AsymmetricCroCo3DStereo(
    CroCoNet,
    huggingface_hub.PyTorchModelHubMixin,
    library_name="dust3r",
    repo_url="https://github.com/naver/dust3r",
    tags=["image-to-3d"],
):
    """Two siamese encoders, followed by two decoders.

    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).

    Both images go through the same encoder (``enc_blocks``). View 1 is
    decoded by ``dec_blocks`` and view 2 by ``dec_blocks2``, an independent
    copy made at construction. Each decoder block also receives the other
    view's tokens. Separate heads (``head1`` / ``head2``) turn the decoder
    outputs into per-view predictions.
    """

    def __init__(
        self,
        output_mode="pts3d",
        head_type="linear",
        depth_mode=("exp", -inf, inf),
        conf_mode=("exp", 1, inf),
        freeze="none",
        landscape_only=True,
        patch_embed_cls="PatchEmbedDust3R",
        **croco_kwargs,
    ):
        """Build the CroCo backbone, the second decoder and both heads.

        Args:
            output_mode: output type forwarded to ``head_factory``.
            head_type: head architecture forwarded to ``head_factory``.
            depth_mode: stored as ``self.depth_mode``.
            conf_mode: stored as ``self.conf_mode``; a falsy value creates
                heads with ``has_conf=False``.
            freeze: parameter group to freeze: ``"none"``, ``"mask"`` or
                ``"encoder"``.
            landscape_only: passed as ``activate`` to
                ``transpose_to_landscape`` for both heads.
            patch_embed_cls: name passed to ``get_patch_embed``.
            **croco_kwargs: forwarded to ``CroCoNet.__init__`` and to
                ``set_downstream_head``.
        """
        # Must be set before CroCoNet.__init__: the overridden
        # _set_patch_embed below reads it (presumably it is called from there).
        self.patch_embed_cls = patch_embed_cls
        # NOTE(review): fill_default_args appears to add CroCoNet's default
        # kwargs into croco_kwargs, so that img_size/patch_size reach
        # set_downstream_head. Confirm in utils.misc.
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # Decoder for view 2: an independent copy of view 1's decoder.
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(
            output_mode,
            head_type,
            landscape_only,
            depth_mode,
            conf_mode,
            **croco_kwargs,
        )
        self.set_freeze(freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or through the parent loader.

        If ``pretrained_model_name_or_path`` is an existing file, it is loaded
        with ``load_model`` on CPU and ``kw`` is ignored. Otherwise the call
        is delegated, with ``kw``, to the next ``from_pretrained`` in the MRO
        (presumably ``PyTorchModelHubMixin``'s).
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device="cpu")
        return super(AsymmetricCroCo3DStereo, cls).from_pretrained(
            pretrained_model_name_or_path, **kw
        )

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """Create ``self.patch_embed`` from the class named by ``self.patch_embed_cls``."""
        self.patch_embed = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim
        )

    def load_state_dict(self, ckpt, **kw):
        """Load weights, initializing ``dec_blocks2`` from ``dec_blocks`` if absent.

        If no key of ``ckpt`` starts with ``"dec_blocks2"``, every
        ``dec_blocks*`` entry is also stored under the matching
        ``dec_blocks2*`` key. The input ``ckpt`` is not modified.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        # NOTE(review): dict(ckpt) does not carry over a ``_metadata``
        # attribute that torch state dicts may have (used for per-module
        # versioning). Confirm this is harmless for this model.
        new_ckpt = dict(ckpt)
        if not any(k.startswith("dec_blocks2") for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith("dec_blocks"):
                    new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):
        """Freeze a predefined group of parameters.

        Args:
            freeze: ``"none"`` (nothing), ``"mask"`` (``mask_token``) or
                ``"encoder"`` (``mask_token``, ``patch_embed`` and
                ``enc_blocks``). Any other value raises ``KeyError`` after
                ``self.freeze`` has been set.
        """
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "mask": [self.mask_token],
            "encoder": [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head: override the hook to do nothing; heads come from set_downstream_head."""
        return

    def set_downstream_head(
        self,
        output_mode,
        head_type,
        landscape_only,
        depth_mode,
        conf_mode,
        patch_size,
        img_size,
        **kw,
    ):
        """Create the two prediction heads and their landscape wrappers.

        Args:
            output_mode: forwarded to ``head_factory``.
            head_type: forwarded to ``head_factory``.
            landscape_only: passed as ``activate`` to ``transpose_to_landscape``.
            depth_mode: stored as ``self.depth_mode``.
            conf_mode: stored as ``self.conf_mode``; ``bool(conf_mode)`` is the
                heads' ``has_conf``.
            patch_size: patch side length.
            img_size: indexable pair whose two entries must both be multiples
                of ``patch_size``.
            **kw: ignored; lets callers pass the full CroCo kwargs.

        Raises:
            AssertionError: if ``img_size`` is not a multiple of ``patch_size``.
        """
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f"{img_size=} must be multiple of {patch_size=}"
        # Set before building the heads: head_factory receives ``self`` and
        # may read these attributes.
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        self.downstream_head1 = head_factory(
            head_type, output_mode, self, has_conf=bool(conf_mode)
        )
        self.downstream_head2 = head_factory(
            head_type, output_mode, self, has_conf=bool(conf_mode)
        )

        # head1/head2 are the wrapped heads used by _downstream_head.
        self.head1 = transpose_to_landscape(
            self.downstream_head1, activate=landscape_only
        )
        self.head2 = transpose_to_landscape(
            self.downstream_head2, activate=landscape_only
        )

    def _encode_image(self, image, true_shape):
        """Run one image batch through the shared encoder.

        Returns:
            ``(tokens, positions, None)``: the ``enc_norm``-normalized encoder
            output, the positions produced by ``patch_embed``, and ``None``.
        """
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # No absolute positional embedding is added; positions are instead
        # passed to every encoder block.
        assert self.enc_pos_embed is None

        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        """Encode two image batches and return ``(feat1, feat2, pos1, pos2)``.

        When both batches have the same trailing two dimensions, they are
        concatenated along dim 0, encoded in one pass and split back.
        Otherwise each batch is encoded separately.
        """
        if img1.shape[-2:] == img2.shape[-2:]:
            out, pos, _ = self._encode_image(
                torch.cat((img1, img2), dim=0),
                torch.cat((true_shape1, true_shape2), dim=0),
            )
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        """Encode both views, reusing work when the batch is symmetrized.

        ``view["true_shape"]`` is used when present. Otherwise every sample is
        given the image tensor's own trailing ``(H, W)``.

        Returns:
            ``((shape1, shape2), (feat1, feat2), (pos1, pos2))``.
        """
        img1 = view1["img"]
        img2 = view2["img"]
        B = img1.shape[0]
        shape1 = view1.get(
            "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
        )
        shape2 = view2.get(
            "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
        )

        if is_symmetrized(view1, view2):
            # Presumably the batch alternates (a, b) / (b, a) pairs, so the
            # even-indexed samples of both views already contain every image.
            # interleave then rebuilds full-batch features.
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(
                img1[::2], img2[::2], shape1[::2], shape2[::2]
            )
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(
                img1, img2, shape1, shape2
            )

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        """Run the two decoders layer by layer in lockstep.

        Returns:
            An iterator over two tuples: the per-layer outputs for view 1 and
            for view 2. Each starts with the raw encoder features, omits the
            ``decoder_embed`` output, and ends with the ``dec_norm``-normalized
            output of the last block.
        """
        final_output = [(f1, f2)]  # encoder outputs

        # Presumably projects encoder features to the decoder width.
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # Both blocks read the previous layer's pair (own view first), so
            # view 2 never sees view 1's update from the same layer.
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            final_output.append((f1, f2))

        # Drop the decoder_embed outputs; keep encoder features + block outputs.
        del final_output[1]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        """Apply ``head1`` or ``head2`` (selected by ``head_num``) to ``decout``."""
        head = getattr(self, f"head{head_num}")
        return head(decout, img_shape)

    def forward(self, view1, view2):
        """Predict per-view outputs for a pair of views.

        Args:
            view1: mapping with an ``"img"`` batch tensor and an optional
                ``"true_shape"``.
            view2: same layout as ``view1``.

        Returns:
            ``(res1, res2)`` from ``head1`` / ``head2``. The ``"pts3d"`` entry
            of ``res2`` is renamed to ``"pts3d_in_other_view"``.
        """
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(
            view1, view2
        )

        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        # Heads run in fp32: CUDA autocast is disabled and tokens are upcast.
        # torch.autocast(device_type="cuda") replaces the deprecated
        # torch.cuda.amp.autocast with identical semantics.
        with torch.autocast(device_type="cuda", enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        # View 2's points are expressed in view 1's frame (see class docstring).
        res2["pts3d_in_other_view"] = res2.pop("pts3d")
        return res1, res2