import torch
import torch.nn as nn
from typing import List, Optional, Tuple

from diffusers import AutoencoderKL
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, Encoder
from diffusers.utils import is_torch_version
class ZeroConv2d(nn.Module):
"""
Zero Convolution layer, similar to the one used in ControlNet.
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.conv.weight.data.zero_()
self.conv.bias.data.zero_()
def forward(self, x):
return self.conv(x)
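
# At initialization a ZeroConv2d outputs all zeros for any input, so the skip
# connections injected below start out as a no-op and only begin to contribute once
# training moves the weights away from zero (the same trick ControlNet uses for its
# injected branches). Quick sanity check (a sketch, not part of the module):
#
#     zc = ZeroConv2d(8, 8)
#     assert torch.count_nonzero(zc(torch.randn(1, 8, 16, 16))).item() == 0
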
class CustomAutoencoderKL(AutoencoderKL):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
scaling_factor: float = 0.18215,
force_upcast: float = True,
use_quant_conv: bool = True,
use_post_quant_conv: bool = True,
mid_block_add_attention: bool = True,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
down_block_types=down_block_types,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
latent_channels=latent_channels,
norm_num_groups=norm_num_groups,
sample_size=sample_size,
scaling_factor=scaling_factor,
force_upcast=force_upcast,
use_quant_conv=use_quant_conv,
use_post_quant_conv=use_post_quant_conv,
mid_block_add_attention=mid_block_add_attention,
)
        # Replace the stock encoder and decoder with custom versions: the encoder
        # emits zero-conv skip connections and the decoder consumes them.
self.decoder = CustomDecoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
self.encoder = CustomEncoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
    def encode(self, x: torch.Tensor, return_dict: bool = True):
        """Return only the zero-conv skip connections from the custom encoder.

        No latent distribution is produced here (the mid block and latent projection
        are bypassed in `CustomEncoder`); `return_dict` is kept for API compatibility.
        """
        _, skip_connections = self.encoder(x)
        return skip_connections
def decode(self, z: torch.Tensor, skip_connections: list, return_dict: bool = True):
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
# Decode the latent representation with skip connections
dec = self.decoder(z, skip_connections)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ):
        # NOTE: this mirrors `AutoencoderKL.forward` and assumes an `encode` that
        # returns `(posterior, skip_connections)`. The `encode` defined above returns
        # only the skip connections, so in practice `encode` and `decode` are called
        # separately, with the latent `z` produced by the base VAE encoder.
        posterior, skip_connections = self.encode(sample, return_dict=True)
        # Sample from the posterior (or take its mode)
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        # Decode the latent representation with the skip connections injected
        dec = self.decode(z, skip_connections, return_dict=False)[0]
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
class CustomDecoder(Decoder):
def __init__(
self,
in_channels: int,
out_channels: int,
up_block_types: Tuple[str, ...],
block_out_channels: Tuple[int, ...],
layers_per_block: int,
norm_num_groups: int,
act_fn: str,
mid_block_add_attention: bool,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
    def forward(
        self,
        sample: torch.Tensor,
        skip_connections: List[torch.Tensor],
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Forward pass of `CustomDecoder`: identical to `Decoder.forward` except that
        the encoder's zero-conv skip connections are added before each up block."""
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
                # up (inject skip connections before each up block, as in the
                # non-checkpointed path below)
                for i, up_block in enumerate(self.up_blocks):
                    if i < len(skip_connections):
                        sample = sample + skip_connections[-(i + 1)]
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
sample = sample.to(upscale_dtype)
                # up (inject skip connections before each up block)
                for i, up_block in enumerate(self.up_blocks):
                    if i < len(skip_connections):
                        sample = sample + skip_connections[-(i + 1)]
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
            # up: inject the matching skip connection before each up block.
            # `skip_connections` is ordered shallow-to-deep, so the last entry pairs
            # with the first (deepest) up block.
            for i, up_block in enumerate(self.up_blocks):
                if i < len(skip_connections):
                    sample = sample + skip_connections[-(i + 1)]
                sample = up_block(sample)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
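
# How the skip connections line up with the decoder's up blocks (illustrative shapes
# assuming the standard SD VAE config block_out_channels=(128, 256, 512, 512) and a
# 768x768 input, consistent with the shapes recorded in CustomEncoder.forward below):
#
#     up block 0: sample (1, 512,  96,  96) + skip_connections[3] (1, 512,  96,  96)
#     up block 1: sample (1, 512, 192, 192) + skip_connections[2] (1, 512, 192, 192)
#     up block 2: sample (1, 512, 384, 384) + skip_connections[1] (1, 512, 384, 384)
#     up block 3: sample (1, 256, 768, 768) + skip_connections[0] (1, 256, 768, 768)
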
class CustomEncoder(Encoder):
r"""
Custom Encoder that adds Zero Convolution layers to each block's output
to generate skip connections.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
mid_block_add_attention: bool = True,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
double_z=double_z,
mid_block_add_attention=mid_block_add_attention,
)
        # One zero convolution per down block. For the first two blocks the output
        # channels are doubled so that each skip connection matches the width of the
        # decoder feature map it is added to (this assumes the usual VAE pattern in
        # which block_out_channels doubles at the second and third blocks, e.g.
        # (128, 256, 512, 512)).
        self.zero_convs = nn.ModuleList()
        for i, channels in enumerate(block_out_channels):
            if i < 2:
                self.zero_convs.append(ZeroConv2d(channels, channels * 2))
            else:
                self.zero_convs.append(ZeroConv2d(channels, channels))
    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        r"""
        Forward pass of the CustomEncoder.
        Args:
            sample (`torch.Tensor`): Input tensor.
        Returns:
            `Tuple[torch.Tensor, List[torch.Tensor]]`:
                - The feature map from the last down block (the mid block and latent
                  projection are intentionally skipped).
                - A list of skip connections, one per down block.
        """
skip_connections = []
# Initial convolution
sample = self.conv_in(sample)
        # Down blocks: after each block, pass its features through the matching zero
        # convolution. All but the last skip connection are upsampled by 2x so that
        # each one matches the spatial size of the decoder feature map it is added to
        # (the final down block does not downsample, so its skip is kept as-is).
        for i, (down_block, zero_conv) in enumerate(zip(self.down_blocks, self.zero_convs)):
            sample = down_block(sample)
            if i != len(self.down_blocks) - 1:
                sample_out = nn.functional.interpolate(
                    zero_conv(sample), scale_factor=2, mode="bilinear", align_corners=False
                )
            else:
                sample_out = zero_conv(sample)
            skip_connections.append(sample_out)
        # Recorded feature shapes for a 768x768 input with the SD VAE channels
        # (128, 256, 512, 512): (1, 128, 768, 768) after conv_in, then
        # (1, 128, 384, 384), (1, 256, 192, 192), (1, 512, 96, 96), (1, 512, 96, 96)
        # after the four down blocks.
        # The mid block and the latent projection (conv_norm_out / conv_act / conv_out)
        # are intentionally skipped: this encoder only produces the skip connections,
        # and the latent `z` is expected to come from a standard VAE encoder.
        return sample, skip_connections
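

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the training pipeline). The small
# block_out_channels below are an arbitrary choice that follows the
# (c, 2c, 4c, 4c) pattern the zero-conv channel doubling assumes; in real use
# the latent `z` would come from a pretrained VAE encoder rather than
# torch.randn.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vae = CustomAutoencoderKL(
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",) * 4,
        up_block_types=("UpDecoderBlock2D",) * 4,
        block_out_channels=(32, 64, 128, 128),
        layers_per_block=1,
        latent_channels=4,
        norm_num_groups=32,
        sample_size=64,
    )
    image = torch.randn(1, 3, 64, 64)
    skips = vae.encode(image)  # list of 4 skip-connection tensors (all zeros at init)
    print([tuple(s.shape) for s in skips])
    z = torch.randn(1, 4, 8, 8)  # stand-in latent at 1/8 of the input resolution
    out = vae.decode(z, skips)
    print(tuple(out.sample.shape))  # expected: (1, 3, 64, 64)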