"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import argparse
import json
import os
from math import ceil
import huggingface_hub
import torch.nn.functional as F
import torch.onnx
from unik3d.models.unik3d import UniK3D
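

# ONNX-friendly wrapper around UniK3D: forward() takes the image tensor directly
# and returns plain (pts_3d, confidence) tensors, so the graph can be traced by
# torch.onnx.export.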
class UniK3DONNX(UniK3D):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs):
        B, _, H, W = rgbs.shape
        features, tokens = self.pixel_encoder(rgbs)

        inputs = {}
        inputs["image"] = rgbs
        # group encoder features and tokens into the slices expected by the decoder
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]

        outputs = self.pixel_decoder(inputs, [])
        # rays are predicted flattened as (B, H*W, 3); bring them back to image layout
        outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
        # 3D points: ray directions scaled by the predicted radius
        pts_3d = outputs["rays"] * outputs["radius"]
        return pts_3d, outputs["confidence"]
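

# Variant used with --with-camera: additionally takes the ground-truth camera as
# per-pixel unprojected rays, shaped (B, 3, H, W) like the image, and feeds them
# to the decoder alongside the encoder features.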
class UniK3DONNXcam(UniK3D):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs, rays):
        B, _, H, W = rgbs.shape
        features, tokens = self.pixel_encoder(rgbs)

        inputs = {}
        inputs["image"] = rgbs
        inputs["rays"] = rays
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]

        outputs = self.pixel_decoder(inputs, [])
        outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
        pts_3d = outputs["rays"] * outputs["radius"]
        return pts_3d, outputs["confidence"]
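

# Trace the wrapper with dummy inputs and write the ONNX graph. Only the batch
# dimension is marked dynamic; the spatial resolution is fixed at export time.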
def export(model, path, shape=(462, 630), with_camera=False):
    model.eval()
    image = torch.rand(1, 3, *shape)
    dynamic_axes_in = {"rgbs": {0: "batch"}}
    inputs = [image]
    if with_camera:
        rays = torch.rand(1, 3, *shape)
        inputs.append(rays)
        dynamic_axes_in["rays"] = {0: "batch"}
    dynamic_axes_out = {
        "pts_3d": {0: "batch"},
        "confidence": {0: "batch"},
    }
    torch.onnx.export(
        model,
        tuple(inputs),
        path,
        input_names=list(dynamic_axes_in.keys()),
        output_names=list(dynamic_axes_out.keys()),
        opset_version=14,
        dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
    )
    print(f"Model exported to {path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export UniK3D model to ONNX")
    parser.add_argument(
        "--backbone",
        type=str,
        default="vitl",
        choices=["vits", "vitb", "vitl"],
        help="Backbone model",
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs=2,
        default=(462, 630),
        help="Input shape. No dynamic shape supported!",
    )
    parser.add_argument(
        "--output-path", type=str, default="unik3d.onnx", help="Output ONNX file"
    )
    parser.add_argument(
        "--with-camera",
        action="store_true",
        help="Export a model that expects the GT camera as unprojected rays at inference",
    )
    args = parser.parse_args()

    backbone = args.backbone
    shape = args.shape
    output_path = args.output_path
    with_camera = args.with_camera

    # force shape to be a multiple of 14 (the patch size of the ViT backbone)
    shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape]
    if list(shape) != list(shape_rounded):
        print(f"Shape {shape} is not a multiple of 14. Rounding to {shape_rounded}")
        shape = shape_rounded

    # assumes the command is run from the root of the repo
    with open(os.path.join("configs", f"config_{backbone}.json")) as f:
        config = json.load(f)
    # tell DINO not to use efficient attention: not exportable
    config["training"]["export"] = True

    model = UniK3DONNX(config) if not with_camera else UniK3DONNXcam(config)
    path = huggingface_hub.hf_hub_download(
        repo_id=f"lpiccinelli/unik3d-{backbone}",
        filename="pytorch_model.bin",
        repo_type="model",
    )
    info = model.load_state_dict(torch.load(path), strict=False)
    print(f"UniK3D_{backbone} is loaded with:")
    print(f"\t missing keys: {info.missing_keys}")
    print(f"\t additional keys: {info.unexpected_keys}")

    export(
        model=model,
        path=os.path.join(os.environ.get("TMPDIR", "."), output_path),
        shape=shape,
        with_camera=with_camera,
    )
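

# Example usage (a sketch; assumes the script is saved as export_onnx.py and run
# from the repo root, and that TMPDIR is unset so the file lands in the cwd):
#
#   python export_onnx.py --backbone vitl --shape 462 630 --output-path unik3d.onnx
#
# The exported graph can then be run with onnxruntime (not a dependency of this
# script), feeding the input names defined above ("rgbs", plus "rays" when
# exported with --with-camera):
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("unik3d.onnx", providers=["CPUExecutionProvider"])
#   rgbs = np.random.rand(1, 3, 462, 630).astype(np.float32)  # replace with a preprocessed image
#   pts_3d, confidence = sess.run(None, {"rgbs": rgbs})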