"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""

import argparse
import json
import os
from math import ceil

import huggingface_hub
import torch.onnx

from unik3d.models.unik3d import UniK3D


class UniK3DONNX(UniK3D):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs):
        B, _, H, W = rgbs.shape
        features, tokens = self.pixel_encoder(rgbs)

        inputs = {}
        inputs["image"] = rgbs
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        outputs = self.pixel_decoder(inputs, [])
        outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
        pts_3d = outputs["rays"] * outputs["radius"]

        return pts_3d, outputs["confidence"]


class UniK3DONNXcam(UniK3D):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs, rays):
        B, _, H, W = rgbs.shape
        features, tokens = self.pixel_encoder(rgbs)

        inputs = {}
        inputs["image"] = rgbs
        inputs["rays"] = rays
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        outputs = self.pixel_decoder(inputs, [])
        outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
        pts_3d = outputs["rays"] * outputs["radius"]

        return pts_3d, outputs["confidence"]
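

# The UniK3DONNXcam variant above expects the known camera as a (1, 3, H, W) tensor of
# per-pixel unprojected rays. Below is a minimal sketch of how such rays could be built
# for a pinhole camera; `pinhole_rays` is an illustrative helper (not part of UniK3D)
# and assumes unit-normalized rays, ignoring pixel-center offsets and lens distortion.
def pinhole_rays(K: torch.Tensor, H: int, W: int) -> torch.Tensor:
    # pixel grid of homogeneous coordinates (u, v, 1), flattened to shape (3, H*W)
    v, u = torch.meshgrid(
        torch.arange(H, dtype=torch.float32),
        torch.arange(W, dtype=torch.float32),
        indexing="ij",
    )
    pix = torch.stack([u, v, torch.ones_like(u)], dim=0).reshape(3, -1)
    # back-project through the inverse intrinsics (K: 3x3) and normalize to unit length
    rays = torch.linalg.inv(K) @ pix
    rays = rays / rays.norm(dim=0, keepdim=True)
    return rays.reshape(3, H, W).unsqueeze(0)  # (1, 3, H, W)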


def export(model, path, shape=(462, 630), with_camera=False):
    model.eval()
    image = torch.rand(1, 3, *shape)
    dynamic_axes_in = {"rgbs": {0: "batch"}}
    inputs = [image]
    if with_camera:
        rays = torch.rand(1, 3, *shape)
        inputs.append(rays)
        dynamic_axes_in["rays"] = {0: "batch"}

    dynamic_axes_out = {
        "pts_3d": {0: "batch"},
        "confidence": {0: "batch"},
    }
    torch.onnx.export(
        model,
        tuple(inputs),
        path,
        input_names=list(dynamic_axes_in.keys()),
        output_names=list(dynamic_axes_out.keys()),
        opset_version=14,
        dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
    )
    print(f"Model exported to {path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export UniK3D model to ONNX")
    parser.add_argument(
        "--backbone",
        type=str,
        default="vitl",
        choices=["vits", "vitb", "vitl"],
        help="Backbone model",
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs=2,
        default=(462, 630),
        help="Input shape. No dyamic shape supported!",
    )
    parser.add_argument(
        "--output-path", type=str, default="unik3d.onnx", help="Output ONNX file"
    )
    parser.add_argument(
        "--with-camera",
        action="store_true",
        help="Export model that expects GT camera as unprojected rays at inference",
    )
    args = parser.parse_args()

    backbone = args.backbone
    shape = args.shape
    output_path = args.output_path
    with_camera = args.with_camera

    # force shape to be a multiple of 14 (ViT patch size), rounding to the nearest multiple
    shape_rounded = [14 * ceil(x / 14 - 0.5) for x in shape]
    if list(shape) != list(shape_rounded):
        print(f"Shape {shape} is not multiple of 14. Rounding to {shape_rounded}")
        shape = shape_rounded

    # assumes command is from root of repo
    with open(os.path.join("configs", f"config_{backbone}.json")) as f:
        config = json.load(f)

    # tell DINO not to use efficient attention: it is not ONNX-exportable
    config["training"]["export"] = True

    model = UniK3DONNX(config) if not with_camera else UniK3DONNXcam(config)
    path = huggingface_hub.hf_hub_download(
        repo_id=f"lpiccinelli/unik3d-{backbone}",
        filename=f"pytorch_model.bin",
        repo_type="model",
    )
    info = model.load_state_dict(torch.load(path), strict=False)
    print(f"UUniK3D_{backbone} is loaded with:")
    print(f"\t missing keys: {info.missing_keys}")
    print(f"\t additional keys: {info.unexpected_keys}")

    export(
        model=model,
        path=os.path.join(os.environ.get("TMPDIR", "."), output_path),
        shape=shape,
        with_camera=with_camera,
    )