"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import argparse
import json
import os
from math import ceil

import huggingface_hub
import torch.nn.functional as F
import torch.onnx

from unik3d.models.unik3d import UniK3D

class UniK3DONNX(UniK3D):
    """ONNX-exportable variant of UniK3D without camera conditioning."""

    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs):
        B, _, H, W = rgbs.shape
        features, tokens = self.pixel_encoder(rgbs)

        inputs = {}
        inputs["image"] = rgbs
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        outputs = self.pixel_decoder(inputs, [])

        # Unproject: per-pixel ray directions scaled by the predicted radius.
        outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
        pts_3d = outputs["rays"] * outputs["radius"]
        return pts_3d, outputs["confidence"]

class UniK3DONNXcam(UniK3D):
    """ONNX-exportable variant of UniK3D conditioned on known camera rays."""

    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs, rays):
        B, _, H, W = rgbs.shape
        features, tokens = self.pixel_encoder(rgbs)

        inputs = {}
        inputs["image"] = rgbs
        inputs["rays"] = rays
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        outputs = self.pixel_decoder(inputs, [])

        # Unproject: per-pixel ray directions scaled by the predicted radius.
        outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
        pts_3d = outputs["rays"] * outputs["radius"]
        return pts_3d, outputs["confidence"]

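# The camera-conditioned export expects unit-norm unprojected rays as a second
# input. A minimal sketch for building them from pinhole intrinsics K is shown
# below; this helper is illustrative only, and the exact ray convention expected
# by UniK3D (pixel-center offset, normalization) should be checked against the
# library's own camera utilities.
#
#   def pinhole_rays(K: torch.Tensor, H: int, W: int) -> torch.Tensor:
#       v, u = torch.meshgrid(
#           torch.arange(H, dtype=torch.float32) + 0.5,
#           torch.arange(W, dtype=torch.float32) + 0.5,
#           indexing="ij",
#       )
#       pix = torch.stack([u, v, torch.ones_like(u)], dim=0).reshape(3, -1)
#       rays = torch.linalg.inv(K) @ pix  # back-project the pixel grid
#       rays = rays / rays.norm(dim=0, keepdim=True)  # unit-length directions
#       return rays.reshape(1, 3, H, W)
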
def export(model, path, shape=(462, 630), with_camera=False):
    model.eval()

    # Dummy inputs for tracing; only the batch dimension is kept dynamic.
    image = torch.rand(1, 3, *shape)
    dynamic_axes_in = {"rgbs": {0: "batch"}}
    inputs = [image]
    if with_camera:
        rays = torch.rand(1, 3, *shape)
        inputs.append(rays)
        dynamic_axes_in["rays"] = {0: "batch"}

    dynamic_axes_out = {
        "pts_3d": {0: "batch"},
        "confidence": {0: "batch"},
    }
    torch.onnx.export(
        model,
        tuple(inputs),
        path,
        input_names=list(dynamic_axes_in.keys()),
        output_names=list(dynamic_axes_out.keys()),
        opset_version=14,
        dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
    )
    print(f"Model exported to {path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export UniK3D model to ONNX")
    parser.add_argument(
        "--backbone",
        type=str,
        default="vitl",
        choices=["vits", "vitb", "vitl"],
        help="Backbone model",
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs=2,
        default=(462, 630),
        help="Input shape. No dynamic shape supported!",
    )
    parser.add_argument(
        "--output-path", type=str, default="unik3d.onnx", help="Output ONNX file"
    )
    parser.add_argument(
        "--with-camera",
        action="store_true",
        help="Export model that expects GT camera as unprojected rays at inference",
    )
    args = parser.parse_args()

    backbone = args.backbone
    shape = args.shape
    output_path = args.output_path
    with_camera = args.with_camera

    # Force the input shape to be a multiple of 14 (the ViT patch size).
    shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape]
    if list(shape) != list(shape_rounded):
        print(f"Shape {shape} is not a multiple of 14. Rounding to {shape_rounded}")
        shape = shape_rounded

    # Assumes the command is run from the root of the repo.
    with open(os.path.join("configs", f"config_{backbone}.json")) as f:
        config = json.load(f)
    # Tell DINO not to use efficient attention: it is not exportable.
    config["training"]["export"] = True

    model = UniK3DONNX(config) if not with_camera else UniK3DONNXcam(config)

    path = huggingface_hub.hf_hub_download(
        repo_id=f"lpiccinelli/unik3d-{backbone}",
        filename="pytorch_model.bin",
        repo_type="model",
    )
    info = model.load_state_dict(torch.load(path), strict=False)
    print(f"UniK3D_{backbone} is loaded with:")
    print(f"\t missing keys: {info.missing_keys}")
    print(f"\t additional keys: {info.unexpected_keys}")

    export(
        model=model,
        path=os.path.join(os.environ.get("TMPDIR", "."), output_path),
        shape=shape,
        with_camera=with_camera,
    )
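
# Example inference with the exported graph, a minimal sketch assuming
# `onnxruntime` is installed; the input/output names match those passed to
# torch.onnx.export above (for the --with-camera export, also feed "rays"
# of shape (1, 3, H, W)):
#
#   import numpy as np
#   import onnxruntime as ort
#
#   session = ort.InferenceSession("unik3d.onnx", providers=["CPUExecutionProvider"])
#   rgbs = np.random.rand(1, 3, 462, 630).astype(np.float32)
#   pts_3d, confidence = session.run(["pts_3d", "confidence"], {"rgbs": rgbs})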