Spaces:
Runtime error
Runtime error
VAE: Support more configurations for Encoder and Decoder blocks
Browse files
VAE: Define encoder compress-all block with channel multiplier
VAE: Support residual connection in the decoder
VAE: Refactor CausalConv3d parameters
lint
xora/models/autoencoders/causal_conv3d.py
CHANGED
|
@@ -11,6 +11,8 @@ class CausalConv3d(nn.Module):
|
|
| 11 |
out_channels,
|
| 12 |
kernel_size: int = 3,
|
| 13 |
stride: Union[int, Tuple[int]] = 1,
|
|
|
|
|
|
|
| 14 |
**kwargs,
|
| 15 |
):
|
| 16 |
super().__init__()
|
|
@@ -21,7 +23,6 @@ class CausalConv3d(nn.Module):
|
|
| 21 |
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 22 |
self.time_kernel_size = kernel_size[0]
|
| 23 |
|
| 24 |
-
dilation = kwargs.pop("dilation", 1)
|
| 25 |
dilation = (dilation, 1, 1)
|
| 26 |
|
| 27 |
height_pad = kernel_size[1] // 2
|
|
@@ -36,6 +37,7 @@ class CausalConv3d(nn.Module):
|
|
| 36 |
dilation=dilation,
|
| 37 |
padding=padding,
|
| 38 |
padding_mode="zeros",
|
|
|
|
| 39 |
)
|
| 40 |
|
| 41 |
def forward(self, x, causal: bool = True):
|
|
|
|
| 11 |
out_channels,
|
| 12 |
kernel_size: int = 3,
|
| 13 |
stride: Union[int, Tuple[int]] = 1,
|
| 14 |
+
dilation: int = 1,
|
| 15 |
+
groups: int = 1,
|
| 16 |
**kwargs,
|
| 17 |
):
|
| 18 |
super().__init__()
|
|
|
|
| 23 |
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 24 |
self.time_kernel_size = kernel_size[0]
|
| 25 |
|
|
|
|
| 26 |
dilation = (dilation, 1, 1)
|
| 27 |
|
| 28 |
height_pad = kernel_size[1] // 2
|
|
|
|
| 37 |
dilation=dilation,
|
| 38 |
padding=padding,
|
| 39 |
padding_mode="zeros",
|
| 40 |
+
groups=groups,
|
| 41 |
)
|
| 42 |
|
| 43 |
def forward(self, x, causal: bool = True):
|
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
|
@@ -78,7 +78,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 78 |
dims=config["dims"],
|
| 79 |
in_channels=config.get("in_channels", 3),
|
| 80 |
out_channels=config["latent_channels"],
|
| 81 |
-
blocks=config
|
| 82 |
patch_size=config.get("patch_size", 1),
|
| 83 |
latent_log_var=latent_log_var,
|
| 84 |
norm_layer=config.get("norm_layer", "group_norm"),
|
|
@@ -88,7 +88,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 88 |
dims=config["dims"],
|
| 89 |
in_channels=config["latent_channels"],
|
| 90 |
out_channels=config.get("out_channels", 3),
|
| 91 |
-
blocks=config
|
| 92 |
patch_size=config.get("patch_size", 1),
|
| 93 |
norm_layer=config.get("norm_layer", "group_norm"),
|
| 94 |
causal=config.get("causal_decoder", False),
|
|
@@ -112,7 +112,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 112 |
out_channels=self.decoder.conv_out.out_channels
|
| 113 |
// self.decoder.patch_size**2,
|
| 114 |
latent_channels=self.decoder.conv_in.in_channels,
|
| 115 |
-
|
|
|
|
| 116 |
scaling_factor=1.0,
|
| 117 |
norm_layer=self.encoder.norm_layer,
|
| 118 |
patch_size=self.encoder.patch_size,
|
|
@@ -242,7 +243,7 @@ class Encoder(nn.Module):
|
|
| 242 |
dims: Union[int, Tuple[int, int]] = 3,
|
| 243 |
in_channels: int = 3,
|
| 244 |
out_channels: int = 3,
|
| 245 |
-
blocks: List[Tuple[str, int]] = [("res_x", 1)],
|
| 246 |
base_channels: int = 128,
|
| 247 |
norm_num_groups: int = 32,
|
| 248 |
patch_size: Union[int, Tuple[int]] = 1,
|
|
@@ -271,20 +272,22 @@ class Encoder(nn.Module):
|
|
| 271 |
|
| 272 |
self.down_blocks = nn.ModuleList([])
|
| 273 |
|
| 274 |
-
for block_name,
|
| 275 |
input_channel = output_channel
|
|
|
|
|
|
|
| 276 |
|
| 277 |
if block_name == "res_x":
|
| 278 |
block = UNetMidBlock3D(
|
| 279 |
dims=dims,
|
| 280 |
in_channels=input_channel,
|
| 281 |
-
num_layers=num_layers,
|
| 282 |
resnet_eps=1e-6,
|
| 283 |
resnet_groups=norm_num_groups,
|
| 284 |
norm_layer=norm_layer,
|
| 285 |
)
|
| 286 |
elif block_name == "res_x_y":
|
| 287 |
-
output_channel = 2 * output_channel
|
| 288 |
block = ResnetBlock3D(
|
| 289 |
dims=dims,
|
| 290 |
in_channels=input_channel,
|
|
@@ -320,6 +323,16 @@ class Encoder(nn.Module):
|
|
| 320 |
stride=(2, 2, 2),
|
| 321 |
causal=True,
|
| 322 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
else:
|
| 324 |
raise ValueError(f"unknown block: {block_name}")
|
| 325 |
|
|
@@ -421,7 +434,7 @@ class Decoder(nn.Module):
|
|
| 421 |
dims,
|
| 422 |
in_channels: int = 3,
|
| 423 |
out_channels: int = 3,
|
| 424 |
-
blocks: List[Tuple[str, int]] = [("res_x", 1)],
|
| 425 |
base_channels: int = 128,
|
| 426 |
layers_per_block: int = 2,
|
| 427 |
norm_num_groups: int = 32,
|
|
@@ -433,9 +446,15 @@ class Decoder(nn.Module):
|
|
| 433 |
self.patch_size = patch_size
|
| 434 |
self.layers_per_block = layers_per_block
|
| 435 |
out_channels = out_channels * patch_size**2
|
| 436 |
-
num_channel_doubles = len([x for x in blocks if x[0] == "res_x_y"])
|
| 437 |
-
output_channel = base_channels * 2**num_channel_doubles
|
| 438 |
self.causal = causal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
self.conv_in = make_conv_nd(
|
| 441 |
dims,
|
|
@@ -449,20 +468,22 @@ class Decoder(nn.Module):
|
|
| 449 |
|
| 450 |
self.up_blocks = nn.ModuleList([])
|
| 451 |
|
| 452 |
-
for block_name,
|
| 453 |
input_channel = output_channel
|
|
|
|
|
|
|
| 454 |
|
| 455 |
if block_name == "res_x":
|
| 456 |
block = UNetMidBlock3D(
|
| 457 |
dims=dims,
|
| 458 |
in_channels=input_channel,
|
| 459 |
-
num_layers=num_layers,
|
| 460 |
resnet_eps=1e-6,
|
| 461 |
resnet_groups=norm_num_groups,
|
| 462 |
norm_layer=norm_layer,
|
| 463 |
)
|
| 464 |
elif block_name == "res_x_y":
|
| 465 |
-
output_channel = output_channel // 2
|
| 466 |
block = ResnetBlock3D(
|
| 467 |
dims=dims,
|
| 468 |
in_channels=input_channel,
|
|
@@ -481,7 +502,10 @@ class Decoder(nn.Module):
|
|
| 481 |
)
|
| 482 |
elif block_name == "compress_all":
|
| 483 |
block = DepthToSpaceUpsample(
|
| 484 |
-
dims=dims,
|
|
|
|
|
|
|
|
|
|
| 485 |
)
|
| 486 |
else:
|
| 487 |
raise ValueError(f"unknown layer: {block_name}")
|
|
@@ -590,7 +614,7 @@ class UNetMidBlock3D(nn.Module):
|
|
| 590 |
|
| 591 |
|
| 592 |
class DepthToSpaceUpsample(nn.Module):
|
| 593 |
-
def __init__(self, dims, in_channels, stride):
|
| 594 |
super().__init__()
|
| 595 |
self.stride = stride
|
| 596 |
self.out_channels = np.prod(stride) * in_channels
|
|
@@ -602,8 +626,21 @@ class DepthToSpaceUpsample(nn.Module):
|
|
| 602 |
stride=1,
|
| 603 |
causal=True,
|
| 604 |
)
|
|
|
|
| 605 |
|
| 606 |
def forward(self, x, causal: bool = True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
x = self.conv(x, causal=causal)
|
| 608 |
x = rearrange(
|
| 609 |
x,
|
|
@@ -614,6 +651,8 @@ class DepthToSpaceUpsample(nn.Module):
|
|
| 614 |
)
|
| 615 |
if self.stride[0] == 2:
|
| 616 |
x = x[:, :, 1:, :, :]
|
|
|
|
|
|
|
| 617 |
return x
|
| 618 |
|
| 619 |
|
|
@@ -647,7 +686,6 @@ class ResnetBlock3D(nn.Module):
|
|
| 647 |
dims: Union[int, Tuple[int, int]],
|
| 648 |
in_channels: int,
|
| 649 |
out_channels: Optional[int] = None,
|
| 650 |
-
conv_shortcut: bool = False,
|
| 651 |
dropout: float = 0.0,
|
| 652 |
groups: int = 32,
|
| 653 |
eps: float = 1e-6,
|
|
@@ -657,7 +695,6 @@ class ResnetBlock3D(nn.Module):
|
|
| 657 |
self.in_channels = in_channels
|
| 658 |
out_channels = in_channels if out_channels is None else out_channels
|
| 659 |
self.out_channels = out_channels
|
| 660 |
-
self.use_conv_shortcut = conv_shortcut
|
| 661 |
|
| 662 |
if norm_layer == "group_norm":
|
| 663 |
self.norm1 = nn.GroupNorm(
|
|
|
|
| 78 |
dims=config["dims"],
|
| 79 |
in_channels=config.get("in_channels", 3),
|
| 80 |
out_channels=config["latent_channels"],
|
| 81 |
+
blocks=config.get("encoder_blocks", config.get("blocks")),
|
| 82 |
patch_size=config.get("patch_size", 1),
|
| 83 |
latent_log_var=latent_log_var,
|
| 84 |
norm_layer=config.get("norm_layer", "group_norm"),
|
|
|
|
| 88 |
dims=config["dims"],
|
| 89 |
in_channels=config["latent_channels"],
|
| 90 |
out_channels=config.get("out_channels", 3),
|
| 91 |
+
blocks=config.get("decoder_blocks", config.get("blocks")),
|
| 92 |
patch_size=config.get("patch_size", 1),
|
| 93 |
norm_layer=config.get("norm_layer", "group_norm"),
|
| 94 |
causal=config.get("causal_decoder", False),
|
|
|
|
| 112 |
out_channels=self.decoder.conv_out.out_channels
|
| 113 |
// self.decoder.patch_size**2,
|
| 114 |
latent_channels=self.decoder.conv_in.in_channels,
|
| 115 |
+
encoder_blocks=self.encoder.blocks_desc,
|
| 116 |
+
decoder_blocks=self.decoder.blocks_desc,
|
| 117 |
scaling_factor=1.0,
|
| 118 |
norm_layer=self.encoder.norm_layer,
|
| 119 |
patch_size=self.encoder.patch_size,
|
|
|
|
| 243 |
dims: Union[int, Tuple[int, int]] = 3,
|
| 244 |
in_channels: int = 3,
|
| 245 |
out_channels: int = 3,
|
| 246 |
+
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
| 247 |
base_channels: int = 128,
|
| 248 |
norm_num_groups: int = 32,
|
| 249 |
patch_size: Union[int, Tuple[int]] = 1,
|
|
|
|
| 272 |
|
| 273 |
self.down_blocks = nn.ModuleList([])
|
| 274 |
|
| 275 |
+
for block_name, block_params in blocks:
|
| 276 |
input_channel = output_channel
|
| 277 |
+
if isinstance(block_params, int):
|
| 278 |
+
block_params = {"num_layers": block_params}
|
| 279 |
|
| 280 |
if block_name == "res_x":
|
| 281 |
block = UNetMidBlock3D(
|
| 282 |
dims=dims,
|
| 283 |
in_channels=input_channel,
|
| 284 |
+
num_layers=block_params["num_layers"],
|
| 285 |
resnet_eps=1e-6,
|
| 286 |
resnet_groups=norm_num_groups,
|
| 287 |
norm_layer=norm_layer,
|
| 288 |
)
|
| 289 |
elif block_name == "res_x_y":
|
| 290 |
+
output_channel = block_params.get("multiplier", 2) * output_channel
|
| 291 |
block = ResnetBlock3D(
|
| 292 |
dims=dims,
|
| 293 |
in_channels=input_channel,
|
|
|
|
| 323 |
stride=(2, 2, 2),
|
| 324 |
causal=True,
|
| 325 |
)
|
| 326 |
+
elif block_name == "compress_all_x_y":
|
| 327 |
+
output_channel = block_params.get("multiplier", 2) * output_channel
|
| 328 |
+
block = make_conv_nd(
|
| 329 |
+
dims=dims,
|
| 330 |
+
in_channels=input_channel,
|
| 331 |
+
out_channels=output_channel,
|
| 332 |
+
kernel_size=3,
|
| 333 |
+
stride=(2, 2, 2),
|
| 334 |
+
causal=True,
|
| 335 |
+
)
|
| 336 |
else:
|
| 337 |
raise ValueError(f"unknown block: {block_name}")
|
| 338 |
|
|
|
|
| 434 |
dims,
|
| 435 |
in_channels: int = 3,
|
| 436 |
out_channels: int = 3,
|
| 437 |
+
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
| 438 |
base_channels: int = 128,
|
| 439 |
layers_per_block: int = 2,
|
| 440 |
norm_num_groups: int = 32,
|
|
|
|
| 446 |
self.patch_size = patch_size
|
| 447 |
self.layers_per_block = layers_per_block
|
| 448 |
out_channels = out_channels * patch_size**2
|
|
|
|
|
|
|
| 449 |
self.causal = causal
|
| 450 |
+
self.blocks_desc = blocks
|
| 451 |
+
|
| 452 |
+
# Compute output channel to be product of all channel-multiplier blocks
|
| 453 |
+
output_channel = base_channels
|
| 454 |
+
for block_name, block_params in list(reversed(blocks)):
|
| 455 |
+
block_params = block_params if isinstance(block_params, dict) else {}
|
| 456 |
+
if block_name == "res_x_y":
|
| 457 |
+
output_channel = output_channel * block_params.get("multiplier", 2)
|
| 458 |
|
| 459 |
self.conv_in = make_conv_nd(
|
| 460 |
dims,
|
|
|
|
| 468 |
|
| 469 |
self.up_blocks = nn.ModuleList([])
|
| 470 |
|
| 471 |
+
for block_name, block_params in list(reversed(blocks)):
|
| 472 |
input_channel = output_channel
|
| 473 |
+
if isinstance(block_params, int):
|
| 474 |
+
block_params = {"num_layers": block_params}
|
| 475 |
|
| 476 |
if block_name == "res_x":
|
| 477 |
block = UNetMidBlock3D(
|
| 478 |
dims=dims,
|
| 479 |
in_channels=input_channel,
|
| 480 |
+
num_layers=block_params["num_layers"],
|
| 481 |
resnet_eps=1e-6,
|
| 482 |
resnet_groups=norm_num_groups,
|
| 483 |
norm_layer=norm_layer,
|
| 484 |
)
|
| 485 |
elif block_name == "res_x_y":
|
| 486 |
+
output_channel = output_channel // block_params.get("multiplier", 2)
|
| 487 |
block = ResnetBlock3D(
|
| 488 |
dims=dims,
|
| 489 |
in_channels=input_channel,
|
|
|
|
| 502 |
)
|
| 503 |
elif block_name == "compress_all":
|
| 504 |
block = DepthToSpaceUpsample(
|
| 505 |
+
dims=dims,
|
| 506 |
+
in_channels=input_channel,
|
| 507 |
+
stride=(2, 2, 2),
|
| 508 |
+
residual=block_params.get("residual", False),
|
| 509 |
)
|
| 510 |
else:
|
| 511 |
raise ValueError(f"unknown layer: {block_name}")
|
|
|
|
| 614 |
|
| 615 |
|
| 616 |
class DepthToSpaceUpsample(nn.Module):
|
| 617 |
+
def __init__(self, dims, in_channels, stride, residual=False):
|
| 618 |
super().__init__()
|
| 619 |
self.stride = stride
|
| 620 |
self.out_channels = np.prod(stride) * in_channels
|
|
|
|
| 626 |
stride=1,
|
| 627 |
causal=True,
|
| 628 |
)
|
| 629 |
+
self.residual = residual
|
| 630 |
|
| 631 |
def forward(self, x, causal: bool = True):
|
| 632 |
+
if self.residual:
|
| 633 |
+
# Reshape and duplicate the input to match the output shape
|
| 634 |
+
x_in = rearrange(
|
| 635 |
+
x,
|
| 636 |
+
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
| 637 |
+
p1=self.stride[0],
|
| 638 |
+
p2=self.stride[1],
|
| 639 |
+
p3=self.stride[2],
|
| 640 |
+
)
|
| 641 |
+
x_in = x_in.repeat(1, np.prod(self.stride), 1, 1, 1)
|
| 642 |
+
if self.stride[0] == 2:
|
| 643 |
+
x_in = x_in[:, :, 1:, :, :]
|
| 644 |
x = self.conv(x, causal=causal)
|
| 645 |
x = rearrange(
|
| 646 |
x,
|
|
|
|
| 651 |
)
|
| 652 |
if self.stride[0] == 2:
|
| 653 |
x = x[:, :, 1:, :, :]
|
| 654 |
+
if self.residual:
|
| 655 |
+
x = x + x_in
|
| 656 |
return x
|
| 657 |
|
| 658 |
|
|
|
|
| 686 |
dims: Union[int, Tuple[int, int]],
|
| 687 |
in_channels: int,
|
| 688 |
out_channels: Optional[int] = None,
|
|
|
|
| 689 |
dropout: float = 0.0,
|
| 690 |
groups: int = 32,
|
| 691 |
eps: float = 1e-6,
|
|
|
|
| 695 |
self.in_channels = in_channels
|
| 696 |
out_channels = in_channels if out_channels is None else out_channels
|
| 697 |
self.out_channels = out_channels
|
|
|
|
| 698 |
|
| 699 |
if norm_layer == "group_norm":
|
| 700 |
self.norm1 = nn.GroupNorm(
|