#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
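        # fuse current features with the previous timestep's features ("past"),
        # with a learned (or identity) skip connection over the current input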
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
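        # fold `stride` consecutive frames into the channel dim, then mix them
        # back down to C channels with a 1x1 conv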
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
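        # expand channels by `stride` with a 1x1 conv, then unfold the result
        # into `stride` consecutive output frames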
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N * T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
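                # memory for timestep t is the block input at t-1 (zeros at t=0):
                # pad a zero frame onto the front of the time axis, drop the last frame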
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
# TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
# need to fix :(
out = []
# iterate over input timesteps and also iterate over blocks.
# because of the cursed TPool/TGrow blocks, this is not a nested loop,
# it's actually a ***graph traversal*** problem! so let's make a queue
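        # each TWorkItem pairs a tensor with the index of the block that should
        # consume it next; pushing successors onto the *front* of the queue makes
        # the traversal depth-first, so frames flow through the net in causal order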
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
# in addition to manually managing our queue, we also need to manually manage our progressbar.
# we'll update it for every source node that we consume.
progress_bar = tqdm(range(T), disable=not show_progress_bar)
# we'll also need a separate addressable memory per node as well
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
# new source node consumed
progress_bar.update(1)
if i == len(model):
# reached end of the graph, append result to output list
out.append(xt)
else:
# fetch the block to process
b = model[i]
if isinstance(b, MemBlock):
# mem blocks are simple since we're visiting the graph in causal order
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_new, i + 1))
elif isinstance(b, TPool):
# pool blocks are miserable
if mem[i] is None:
mem[i] = [] # pool memory is itself a queue of inputs to pool
mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        # pool mem is in an invalid state; we should have pooled before this
                        raise ValueError(f"TPool memory overfilled: {len(mem[i])} items for stride {b.stride}")
elif len(mem[i]) < b.stride:
# pool mem is not yet full, go back to processing the work queue
pass
else:
# pool mem is ready, run the pool block
N, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
# reset the pool mem
mem[i] = []
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
# each tgrow has multiple successor nodes
for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_next, i + 1))
else:
# normal block with no funny business
xt = b(xt)
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
latent_channels = 16
image_channels = 3
def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
"""Initialize pretrained TAEHV from the given checkpoint.
        Args:
            checkpoint_path: path to weight file to load. taehv.pth for Hunyuan Video, taew2_1.pth for Wan 2.1.
            decoder_time_upscale: whether temporal upsampling is enabled for each block. Upsampling can be disabled for a cheaper preview.
            decoder_space_upscale: whether spatial upsampling is enabled for each block. Upsampling can be disabled for a cheaper preview.
"""
super().__init__()
self.encoder = nn.Sequential(
conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
conv(64, TAEHV.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
        self.decoder = nn.Sequential(
            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
        )
if checkpoint_path is not None:
self.load_state_dict(self.patch_tgrow_layers(torch.load(
checkpoint_path, map_location="cpu", weights_only=True)))
    def patch_tgrow_layers(self, sd):
        """Patch TGrow layers to use a smaller kernel if needed.

        When temporal upsampling is disabled, a TGrow conv in this decoder has
        fewer output channels than its pretrained counterpart, so only the
        last-timestep slice of the pretrained kernel is kept.

        Args:
            sd: state dict to patch
        Returns the patched state dict.
        """
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if sd[key].shape[0] > new_sd[key].shape[0]:
# take the last-timestep output channels
sd[key] = sd[key][-new_sd[key].shape[0]:]
return sd
def encode_video(self, x, parallel=True, show_progress_bar=True):
"""Encode a sequence of frames.
Args:
x: input NTCHW RGB (C=3) tensor with values in [0, 1].
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW latent tensor with ~Gaussian values.
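        Example (a sketch; shapes assume the default encoder configuration):
            vid = torch.rand(1, 16, 3, 256, 256)  # N, T, C, H, W in [0, 1]
            lat = taehv.encode_video(vid)  # -> (1, 4, 16, 32, 32)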
"""
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
def decode_video(self, x, parallel=True, show_progress_bar=False):
"""Decode a sequence of frames.
Args:
            x: input NTCHW latent (C=16) tensor with ~Gaussian values.
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW RGB tensor with ~[0, 1] values.
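        Example (a sketch; shapes assume the default decoder configuration,
        with the initial-frame trimming below left disabled):
            lat = torch.randn(1, 4, 16, 32, 32)
            vid = taehv.decode_video(lat)  # -> (1, 16, 3, 256, 256)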
"""
x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
        # NOTE: trimming of the first self.frames_to_trim warmup frames is
        # disabled here; uncomment the line below to restore it
        # return x[:, self.frames_to_trim:]
        return x
    def forward(self, x):
        # round-trip: encode to latents, then decode back to RGB
        return self.decode_video(self.encode_video(x))
@torch.no_grad()
def main():
"""Run TAEHV roundtrip reconstruction on the given video paths."""
import os
import sys
import cv2 # no highly esteemed deed is commemorated here
class VideoTensorReader:
def __init__(self, video_file_path):
self.cap = cv2.VideoCapture(video_file_path)
assert self.cap.isOpened(), f"Could not load {video_file_path}"
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
def __iter__(self):
return self
def __next__(self):
ret, frame = self.cap.read()
if not ret:
self.cap.release()
raise StopIteration # End of video or error
return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
class VideoTensorWriter:
def __init__(self, video_file_path, width_height, fps=30):
self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
def write(self, frame_tensor):
assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(),
cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
def __del__(self):
if hasattr(self, 'writer'):
self.writer.release()
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
dtype = torch.float16
checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
print(
f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
for video_path in sys.argv[1:]:
print(f"Processing {video_path}...")
video_in = VideoTensorReader(video_path)
video = torch.stack(list(video_in), 0)[None]
        # convert to a floating-point device tensor in [0, 1]
        vid_dev = video.to(dev, dtype).div_(255.0)
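        # rough heuristic: ~100M elements is on the order of a 100 MB uint8 video;
        # larger inputs are encoded/decoded frame by frame to bound peak memory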
if video.numel() < 100_000_000:
print(f" {video_path} seems small enough, will process all frames in parallel")
vid_enc = taehv.encode_video(vid_dev)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc)
print(f" Decoded {video_path} -> {vid_dec.shape}")
else:
print(f" {video_path} seems large, will process each frame sequentially")
vid_enc = taehv.encode_video(vid_dev, parallel=False)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc, parallel=False)
print(f" Decoded {video_path} -> {vid_dec.shape}")
video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
video_out = VideoTensorWriter(
video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
video_out.write(frame)
print(f" Saved to {video_out_path}")
if __name__ == "__main__":
main()