""" Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import argparse import json import os from math import ceil import huggingface_hub import torch.nn.functional as F import torch.onnx from unik3d.models.unik3d import UniK3D class UniK3DONNX(UniK3D): def __init__( self, config, eps: float = 1e-6, **kwargs, ): super().__init__(config, eps) def forward(self, rgbs): B, _, H, W = rgbs.shape features, tokens = self.pixel_encoder(rgbs) inputs = {} inputs["image"] = rgbs inputs["features"] = [ self.stacking_fn(features[i:j]).contiguous() for i, j in self.slices_encoder_range ] inputs["tokens"] = [ self.stacking_fn(tokens[i:j]).contiguous() for i, j in self.slices_encoder_range ] outputs = self.pixel_decoder(inputs, []) outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W) pts_3d = outputs["rays"] * outputs["radius"] return pts_3d, outputs["confidence"] class UniK3DONNXcam(UniK3D): def __init__( self, config, eps: float = 1e-6, **kwargs, ): super().__init__(config, eps) def forward(self, rgbs, rays): B, _, H, W = rgbs.shape features, tokens = self.pixel_encoder(rgbs) inputs = {} inputs["image"] = rgbs inputs["rays"] = rays inputs["features"] = [ self.stacking_fn(features[i:j]).contiguous() for i, j in self.slices_encoder_range ] inputs["tokens"] = [ self.stacking_fn(tokens[i:j]).contiguous() for i, j in self.slices_encoder_range ] outputs = self.pixel_decoder(inputs, []) outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W) pts_3d = outputs["rays"] * outputs["radius"] return pts_3d, outputs["confidence"] def export(model, path, shape=(462, 630), with_camera=False): model.eval() image = torch.rand(1, 3, *shape) dynamic_axes_in = {"rgbs": {0: "batch"}} inputs = [image] if with_camera: rays = torch.rand(1, 3, *shape) inputs.append(rays) dynamic_axes_in["rays"] = {0: "batch"} dynamic_axes_out = { "pts_3d": {0: "batch"}, "confidence": {0: "batch"}, } torch.onnx.export( model, tuple(inputs), path, input_names=list(dynamic_axes_in.keys()), output_names=list(dynamic_axes_out.keys()), opset_version=14, dynamic_axes={**dynamic_axes_in, **dynamic_axes_out}, ) print(f"Model exported to {path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Export UniK3D model to ONNX") parser.add_argument( "--backbone", type=str, default="vitl", choices=["vits", "vitb", "vitl"], help="Backbone model", ) parser.add_argument( "--shape", type=int, nargs=2, default=(462, 630), help="Input shape. No dyamic shape supported!", ) parser.add_argument( "--output-path", type=str, default="unik3d.onnx", help="Output ONNX file" ) parser.add_argument( "--with-camera", action="store_true", help="Export model that expects GT camera as unprojected rays at inference", ) args = parser.parse_args() backbone = args.backbone shape = args.shape output_path = args.output_path with_camera = args.with_camera # force shape to be multiple of 14 shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape] if list(shape) != list(shape_rounded): print(f"Shape {shape} is not multiple of 14. 
Rounding to {shape_rounded}") shape = shape_rounded # assumes command is from root of repo with open(os.path.join("configs", f"config_{backbone}.json")) as f: config = json.load(f) # tell DINO not to use efficient attention: not exportable config["training"]["export"] = True model = UniK3DONNX(config) if not with_camera else UniK3DONNXcam(config) path = huggingface_hub.hf_hub_download( repo_id=f"lpiccinelli/unik3d-{backbone}", filename=f"pytorch_model.bin", repo_type="model", ) info = model.load_state_dict(torch.load(path), strict=False) print(f"UUniK3D_{backbone} is loaded with:") print(f"\t missing keys: {info.missing_keys}") print(f"\t additional keys: {info.unexpected_keys}") export( model=model, path=os.path.join(os.environ.get("TMPDIR", "."), output_path), shape=shape, with_camera=with_camera, )
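
# --------------------------------------------------------------------------- #
# Illustrative sketch (not invoked by this script): running the exported model
# with onnxruntime. The input name "rgbs" and output names "pts_3d" and
# "confidence" come from the export above; the file name, the use of
# onnxruntime, and the dummy input are assumptions, and any resizing or
# normalization expected by the network must be handled by the caller.
#
#   import numpy as np
#   import onnxruntime as ort
#
#   session = ort.InferenceSession("unik3d.onnx", providers=["CPUExecutionProvider"])
#   dummy = np.random.rand(1, 3, 462, 630).astype(np.float32)  # matches default export shape
#   pts_3d, confidence = session.run(["pts_3d", "confidence"], {"rgbs": dummy})
#   print(pts_3d.shape, confidence.shape)
# --------------------------------------------------------------------------- #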