# self-forcing/demo_utils/vae_torch2trt.py
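"""Export the Wan2.1 VAE decoder to ONNX, build a TensorRT engine from it
(FP16, optionally INT8-calibrated), and run a quick smoke test on the plan."""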
import random
import sys
from pathlib import Path

import torch
import tensorrt as trt

# PyCUDA is only needed for the optional INT8 calibration path; fall back
# gracefully so a plain FP16 build still works without it.
try:
    import pycuda.driver as cuda
    import pycuda.autoinit  # noqa: F401  (creates the CUDA context)
except ImportError:
    cuda = None

from demo_utils.vae import (
    VAEDecoderWrapperSingle,  # main nn.Module
    ZERO_VAE_CACHE,           # helper constants shipped with the code base
)
from utils.dataset import ShardingLMDBDataset
data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard"
dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8))
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=0
)
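# the LMDB shard above is only consumed by the INT8 calibrator further down;
# a plain FP16 build does not need it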
# ─────────────────────────────────────────────────────────
# 1️⃣ Build dummy inputs for the export
#    (the decoder wrapper itself is imported from demo_utils.vae above)
# ─────────────────────────────────────────────────────────
# --- dummy tensors matching the decoder's expected input shapes ---
dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda()
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
dummy_cache_input = [
torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s
for s in ZERO_VAE_CACHE # keep exactly the same ordering
]
inputs = [dummy_input, is_first_frame, *dummy_cache_input]
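# positional order: latent z, the use_cache flag, then every cache tensor in
# ZERO_VAE_CACHE order; this must line up with `all_inputs_names` below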
# ─────────────────────────────────────────────────────────
# 2️⃣ Export → ONNX
# ─────────────────────────────────────────────────────────
model = VAEDecoderWrapperSingle().half().cuda().eval()
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
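# keep only the weights the decoder wrapper owns (decoder.* and conv2);
# the encoder half of the VAE checkpoint is not needed for export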
decoder_state_dict = {}
for key, value in vae_state_dict.items():
if 'decoder.' in key or 'conv2' in key:
decoder_state_dict[key] = value
model.load_state_dict(decoder_state_dict)
model = model.half().cuda().eval()
onnx_path = Path("vae_decoder.onnx")
feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))]
all_inputs_names = ["z", "use_cache"] + feat_names
with torch.inference_mode():
torch.onnx.export(
model,
tuple(inputs), # must be a tuple
onnx_path.as_posix(),
input_names=all_inputs_names,
output_names=["rgb_out", "cache_out"],
opset_version=17,
do_constant_folding=True,
dynamo=True
)
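    # dynamo=True selects the newer torch.export-based ONNX exporter; on older
    # PyTorch builds that do not accept this flag, drop it to fall back to the
    # legacy TorchScript exporter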
print(f"βœ… ONNX graph saved to {onnx_path.resolve()}")
# (Optional) quick sanity-check with ONNX-Runtime
try:
import onnxruntime as ort
sess = ort.InferenceSession(onnx_path.as_posix(),
providers=["CUDAExecutionProvider"])
ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)}
_ = sess.run(None, ort_inputs)
print("βœ… ONNX graph is executable")
except Exception as e:
print("⚠️ ONNX check failed:", e)
# ─────────────────────────────────────────────────────────
# 3️⃣ Build the TensorRT engine
# ─────────────────────────────────────────────────────────
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(parser.get_error(i))
sys.exit("❌ ONNX β†’ TRT parsing failed")
# ─────────────────────────────────────────────────────────
# helper: version-agnostic workspace limit
# ─────────────────────────────────────────────────────────
def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
"""
    TRT < 10.x → config.max_workspace_size
    TRT ≥ 10.x → config.set_memory_pool_limit(...)
"""
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
config.max_workspace_size = bytes_
else: # TRT 10+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
bytes_)
# ─────────────────────────────────────────────────────────
# (optional) INT-8 calibrator
# ─────────────────────────────────────────────────────────
# ‼ Only keep this block if you really need INT-8 ‼
# (the whole INT8 path is skipped automatically when PyCUDA is not installed,
#  thanks to the guarded import at the top of the file)
class VAECalibrator(trt.IInt8EntropyCalibrator2):
def __init__(self, loader, cache="calibration.cache", max_batches=10):
super().__init__()
self.loader = iter(loader)
self.batch_size = loader.batch_size or 1
self.max_batches = max_batches
self.count = 0
self.cache_file = cache
self.stream = cuda.Stream()
self.dev_ptrs = {}
# --- TRT 10 needs BOTH spellings ---
def get_batch_size(self):
return self.batch_size
def getBatchSize(self):
return self.batch_size
def get_batch(self, names):
if self.count >= self.max_batches:
return None
        # roll the decoder cache forward a random number of frames (0-10) so
        # calibration sees both the first-frame path and the cached path
        vae_idx = random.randint(0, 10)
        data = next(self.loader)
        # move latents to the GPU in fp16 to match the exported decoder
        latent = data['ode_latent'][0][:, :1].half().cuda()
        is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
        feat_cache = ZERO_VAE_CACHE
        for i in range(vae_idx):
            inputs = [latent, is_first_frame, *feat_cache]
            with torch.inference_mode():
                outputs = model(*inputs)
            latent = data['ode_latent'][0][:, i + 1:i + 2].half().cuda()
            is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16)
            feat_cache = outputs[1:]
        # -------- copy each calibration input into its device buffer --------
z_np = latent.cpu().numpy().astype('float32')
ptrs = [] # list[int] – one entry per name
for name in names: # <-- match TRT's binding order
if name == "z":
arr = z_np
elif name == "use_cache":
arr = is_first_frame.cpu().numpy().astype('float32')
else:
idx = int(name.split('_')[-1]) # "vae_cache_17" -> 17
arr = feat_cache[idx].cpu().numpy().astype('float32')
if name not in self.dev_ptrs:
self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes)
cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream)
ptrs.append(int(self.dev_ptrs[name])) # ***int() is required***
self.stream.synchronize()
self.count += 1
print(f"Calibration batch {self.count}/{self.max_batches}")
return ptrs
# --- calibration-cache helpers (both spellings) ---
def read_calibration_cache(self):
try:
with open(self.cache_file, "rb") as f:
return f.read()
except FileNotFoundError:
return None
def readCalibrationCache(self):
return self.read_calibration_cache()
def write_calibration_cache(self, cache):
with open(self.cache_file, "wb") as f:
f.write(cache)
def writeCalibrationCache(self, cache):
self.write_calibration_cache(cache)
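# NOTE: read_calibration_cache() means an existing calibration.cache is reused
# on later builds; delete the file to force a fresh calibration run after the
# network or the calibration data changes.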
# ─────────────────────────────────────────────────────────
# Builder-config + optimisation profile
# ─────────────────────────────────────────────────────────
config = builder.create_builder_config()
set_workspace(config, 4 << 30) # 4 GB
# ► enable FP16 if possible
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
# ► enable INT-8 (delete this block if you don't need it)
if cuda is not None:
config.set_flag(trt.BuilderFlag.INT8)
    # the calibrator draws representative latents from the LMDB dataset above
calib = VAECalibrator(dataloader)
    # some TensorRT builds expose a setter, others only the attribute:
    if hasattr(config, "set_int8_calibrator"):
        config.set_int8_calibrator(calib)
    else:
        config.int8_calibrator = calib
# ---- optimisation profile ----
profile = builder.create_optimization_profile()
profile.set_shape(all_inputs_names[0], # latent z
min=(1, 1, 16, 60, 104),
opt=(1, 1, 16, 60, 104),
max=(1, 1, 16, 60, 104))
profile.set_shape("use_cache", # scalar flag
min=(1,), opt=(1,), max=(1,))
for name, tensor in zip(all_inputs_names[2:], dummy_cache_input):
profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape)
config.add_optimization_profile(profile)
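# every dimension is pinned (min = opt = max), so the profile simply documents
# the single static shape this engine will accept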
# ─────────────────────────────────────────────────────────
# Build the engine (API changed in TRT-10)
# ─────────────────────────────────────────────────────────
print("βš™οΈ Building engine … (can take a minute)")
if hasattr(builder, "build_serialized_network"): # TRT 10+
serialized_engine = builder.build_serialized_network(network, config)
assert serialized_engine is not None, "build_serialized_network() failed"
plan_path = Path("checkpoints/vae_decoder_int8.trt")
plan_path.write_bytes(serialized_engine)
engine_bytes = serialized_engine # keep for smoke-test
else: # TRT ≀ 9
engine = builder.build_engine(network, config)
assert engine is not None, "build_engine() returned None"
plan_path = Path("checkpoints/vae_decoder_int8.trt")
plan_path.write_bytes(engine.serialize())
engine_bytes = engine.serialize()
print(f"βœ… TensorRT engine written to {plan_path.resolve()}")
# ─────────────────────────────────────────────────────────
# 4️⃣ Quick smoke-test with the brand-new engine
# ─────────────────────────────────────────────────────────
with trt.Runtime(TRT_LOGGER) as rt:
engine = rt.deserialize_cuda_engine(engine_bytes)
context = engine.create_execution_context()
stream = torch.cuda.current_stream().cuda_stream
# pre-allocate device buffers once
device_buffers, outputs = {}, []
dtype_map = {trt.float32: torch.float32,
trt.float16: torch.float16,
trt.int8: torch.int8,
trt.int32: torch.int32}
for name, tensor in zip(all_inputs_names, inputs):
if -1 in engine.get_tensor_shape(name): # dynamic input
context.set_input_shape(name, tensor.shape)
context.set_tensor_address(name, int(tensor.data_ptr()))
device_buffers[name] = tensor
    context.infer_shapes()                       # propagate input shapes to the outputs
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = tuple(context.get_tensor_shape(name))
dtype = dtype_map[engine.get_tensor_dtype(name)]
out = torch.empty(shape, dtype=dtype, device="cuda")
context.set_tensor_address(name, int(out.data_ptr()))
outputs.append(out)
print(f"output {name} shape: {shape}")
context.execute_async_v3(stream_handle=stream)
torch.cuda.current_stream().synchronize()
print("βœ… TRT execution OK – first output shape:", outputs[0].shape)