from demo_utils.vae import (
    VAEDecoderWrapperSingle,  # main nn.Module
    ZERO_VAE_CACHE,           # helper constants shipped with your code base
)

try:
    import pycuda.driver as cuda    # needed only for the INT8 calibrator below
    import pycuda.autoinit          # noqa: F401  (creates/activates a CUDA context)
except ImportError:                 # INT8 is skipped gracefully if PyCUDA is absent
    cuda = None

import sys
from pathlib import Path

import torch
import tensorrt as trt

from utils.dataset import ShardingLMDBDataset

data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard"
dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8))
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=0
)

# ─────────────────────────────────────────────────────────
# 1️⃣  Bring the PyTorch model into scope
#     (the decoder wrapper is imported from `demo_utils.vae`)
# ─────────────────────────────────────────────────────────

# --- dummy tensors (fixed shapes used for export and the TRT profile) ---
dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda()
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
dummy_cache_input = [
    torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s
    for s in ZERO_VAE_CACHE               # keep exactly the same ordering
]
inputs = [dummy_input, is_first_frame, *dummy_cache_input]

# ─────────────────────────────────────────────────────────
# 2️⃣  Export → ONNX
# ─────────────────────────────────────────────────────────
model = VAEDecoderWrapperSingle().half().cuda().eval()

vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
decoder_state_dict = {}
for key, value in vae_state_dict.items():
    if 'decoder.' in key or 'conv2' in key:
        decoder_state_dict[key] = value
model.load_state_dict(decoder_state_dict)
model = model.half().cuda().eval()   # re-assert dtype/device after loading the checkpoint

onnx_path = Path("vae_decoder.onnx")
feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))]
all_inputs_names = ["z", "use_cache"] + feat_names

with torch.inference_mode():
    torch.onnx.export(
        model,
        tuple(inputs),                                        # must be a tuple
        onnx_path.as_posix(),
        input_names=all_inputs_names,
        output_names=["rgb_out", "cache_out"],
        opset_version=17,
        do_constant_folding=True,
        dynamo=True
    )
print(f"βœ…  ONNX graph saved to {onnx_path.resolve()}")

# (Optional) quick sanity-check with ONNX-Runtime
try:
    import onnxruntime as ort
    sess = ort.InferenceSession(onnx_path.as_posix(),
                                providers=["CUDAExecutionProvider"])
    ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)}
    _ = sess.run(None, ort_inputs)
    print("βœ…  ONNX graph is executable")
except Exception as e:
    print("⚠️  ONNX check failed:", e)

# ─────────────────────────────────────────────────────────
# 3️⃣  Build the TensorRT engine
# ─────────────────────────────────────────────────────────
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

with open(onnx_path, "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        sys.exit("❌  ONNX β†’ TRT parsing failed")

# ─────────────────────────────────────────────────────────
# helper: version-agnostic workspace limit
# ─────────────────────────────────────────────────────────


def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
    """
    TRT < 10.x  →  config.max_workspace_size
    TRT ≥ 10.x  →  config.set_memory_pool_limit(...)
    """
    if hasattr(config, "max_workspace_size"):                     # TRT 8 / 9
        config.max_workspace_size = bytes_
    else:                                                         # TRT 10+
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                     bytes_)

# ─────────────────────────────────────────────────────────
# (optional) INT-8 calibrator
# ─────────────────────────────────────────────────────────
# ‼ Only keep this block if you really need INT-8 ‼  (it is skipped gracefully if PyCUDA is not present)


class VAECalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, loader, cache="calibration.cache", max_batches=10):
        super().__init__()
        self.loader = iter(loader)
        self.batch_size = loader.batch_size or 1
        self.max_batches = max_batches
        self.count = 0
        self.cache_file = cache
        self.stream = cuda.Stream()
        self.dev_ptrs = {}

    # --- TRT 10 needs BOTH spellings ---
    def get_batch_size(self):
        return self.batch_size

    def getBatchSize(self):
        return self.batch_size

    def get_batch(self, names):
        if self.count >= self.max_batches:
            return None

        # Run the decoder for a random number of warm-up steps (0–10) so the
        # calibration data covers both the first-frame and cached-frame paths.
        import random
        vae_idx = random.randint(0, 10)
        data = next(self.loader)

        latent = data['ode_latent'][0][:, :1].to(device="cuda", dtype=torch.float16)  # match the decoder's device/dtype
        is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
        feat_cache = ZERO_VAE_CACHE
        for i in range(vae_idx):
            inputs = [latent, is_first_frame, *feat_cache]
            with torch.inference_mode():
                outputs = model(*inputs)
            latent = data['ode_latent'][0][:, i + 1:i + 2].to(device="cuda", dtype=torch.float16)
            is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16)
            feat_cache = outputs[1:]

        # -------- stage the calibration inputs as float32 host buffers --------
        z_np = latent.cpu().numpy().astype('float32')

        ptrs = []                # list[int] – one entry per name
        for name in names:         # <-- match TRT's binding order
            if name == "z":
                arr = z_np
            elif name == "use_cache":
                arr = is_first_frame.cpu().numpy().astype('float32')
            else:
                idx = int(name.split('_')[-1])   # "vae_cache_17" -> 17
                arr = feat_cache[idx].cpu().numpy().astype('float32')

            if name not in self.dev_ptrs:
                self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes)

            cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream)
            ptrs.append(int(self.dev_ptrs[name]))   # ***int() is required***

        self.stream.synchronize()
        self.count += 1
        print(f"Calibration batch {self.count}/{self.max_batches}")
        return ptrs

    # --- calibration-cache helpers (both spellings) ---
    def read_calibration_cache(self):
        try:
            with open(self.cache_file, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return None

    def readCalibrationCache(self):
        return self.read_calibration_cache()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

    def writeCalibrationCache(self, cache):
        self.write_calibration_cache(cache)


# ─────────────────────────────────────────────────────────
# Builder-config + optimisation profile
# ─────────────────────────────────────────────────────────
config = builder.create_builder_config()
set_workspace(config, 4 << 30)                    # 4 GB

# ► enable FP16 if possible
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# ► enable INT-8  (delete this block if you don’t need it)
if cuda is not None:
    config.set_flag(trt.BuilderFlag.INT8)
    # the calibrator pulls representative latents from the LMDB dataloader above
    calib = VAECalibrator(dataloader)
    # TRT-10 renamed the setter:
    if hasattr(config, "set_int8_calibrator"):    # TRT 10+
        config.set_int8_calibrator(calib)
    else:                                         # TRT ≤ 9
        config.int8_calibrator = calib

# ---- optimisation profile ----
profile = builder.create_optimization_profile()
profile.set_shape(all_inputs_names[0],            # latent z
                  min=(1, 1, 16, 60, 104),
                  opt=(1, 1, 16, 60, 104),
                  max=(1, 1, 16, 60, 104))
profile.set_shape("use_cache",               # scalar flag
                  min=(1,), opt=(1,), max=(1,))
for name, tensor in zip(all_inputs_names[2:], dummy_cache_input):
    profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape)

config.add_optimization_profile(profile)
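
# If any input ends up with dynamic dimensions, INT8 calibration also needs an
# explicit calibration profile. This call is an extra safeguard (not in the
# original script) and is harmless for fully static shapes.
if cuda is not None and hasattr(config, "set_calibration_profile"):
    config.set_calibration_profile(profile)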

# ─────────────────────────────────────────────────────────
# Build the engine  (API changed in TRT-10)
# ─────────────────────────────────────────────────────────
print("βš™οΈ  Building engine … (can take a minute)")

if hasattr(builder, "build_serialized_network"):          # TRT 10+
    serialized_engine = builder.build_serialized_network(network, config)
    assert serialized_engine is not None, "build_serialized_network() failed"
    plan_path = Path("checkpoints/vae_decoder_int8.trt")
    plan_path.write_bytes(serialized_engine)
    engine_bytes = serialized_engine                      # keep for smoke-test
else:                                                     # TRT ≀ 9
    engine = builder.build_engine(network, config)
    assert engine is not None, "build_engine() returned None"
    plan_path = Path("checkpoints/vae_decoder_int8.trt")
    plan_path.write_bytes(engine.serialize())
    engine_bytes = engine.serialize()

print(f"βœ…  TensorRT engine written to {plan_path.resolve()}")

# ─────────────────────────────────────────────────────────
# 4️⃣  Quick smoke-test with the brand-new engine
# ─────────────────────────────────────────────────────────
with trt.Runtime(TRT_LOGGER) as rt:
    engine = rt.deserialize_cuda_engine(engine_bytes)
    context = engine.create_execution_context()
    stream = torch.cuda.current_stream().cuda_stream

    # pre-allocate device buffers once
    device_buffers, outputs = {}, []
    dtype_map = {trt.float32: torch.float32,
                 trt.float16: torch.float16,
                 trt.int8:    torch.int8,
                 trt.int32:   torch.int32}

    for name, tensor in zip(all_inputs_names, inputs):
        if -1 in engine.get_tensor_shape(name):            # dynamic input
            context.set_input_shape(name, tensor.shape)
        context.set_tensor_address(name, int(tensor.data_ptr()))
        device_buffers[name] = tensor

    context.infer_shapes()                                 # propagate ⇒ output shapes
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
            shape = tuple(context.get_tensor_shape(name))
            dtype = dtype_map[engine.get_tensor_dtype(name)]
            out = torch.empty(shape, dtype=dtype, device="cuda")
            context.set_tensor_address(name, int(out.data_ptr()))
            outputs.append(out)
            print(f"output {name} shape: {shape}")

    context.execute_async_v3(stream_handle=stream)
    torch.cuda.current_stream().synchronize()
    print("βœ…  TRT execution OK – first output shape:", outputs[0].shape)