# Hugging Face Spaces header (page-extraction residue): "Spaces: Running on Zero"
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple, Union

from diffusers import AutoencoderKL
from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution, Decoder, Encoder
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.unets.unet_2d_blocks import (
    AutoencoderTinyBlock,
    UNetMidBlock2D,
    get_down_block,
    get_up_block,
)
from diffusers.utils import is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
class ZeroConv2d(nn.Module):
    """1x1 convolution whose weights and bias start at zero.

    Mirrors the zero-initialized convolutions used by ControlNet: at the
    start of training the layer outputs all zeros, so a freshly added
    skip path contributes nothing until its parameters are learned.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # Zero-init both parameters so the module maps everything to 0 at step 0.
        nn.init.zeros_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)
class CustomAutoencoderKL(AutoencoderKL):
    """AutoencoderKL variant wired for ControlNet-style additive skip connections.

    After the stock ``AutoencoderKL`` constructor builds the standard
    encoder/decoder, both are replaced: :class:`CustomEncoder` additionally
    returns one tapped tensor per down block, and :class:`CustomDecoder`
    adds those tensors to its feature maps before each up block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
    ):
        # Let the parent build the standard VAE (including quant convs and
        # config registration), then swap in the skip-aware encoder/decoder.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        # Add Zero Convolution layers to the encoder
        # self.zero_convs = nn.ModuleList()
        # for i, out_channels_ in enumerate(block_out_channels):
        #     self.zero_convs.append(ZeroConv2d(out_channels_, out_channels_))
        # Replace the stock decoder with one that consumes skip connections.
        self.decoder = CustomDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )
        # Replace the stock encoder with one that produces skip connections.
        self.encoder = CustomEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )

    def encode(self, x: torch.Tensor, return_dict: bool = True):
        """Run the custom encoder and return only the skip-connection list.

        NOTE(review): this overrides ``AutoencoderKL.encode`` with a
        different contract — the encoder's feature output is discarded,
        ``return_dict`` is ignored, and no ``DiagonalGaussianDistribution``
        posterior is produced. Callers expecting the parent's
        ``AutoencoderKLOutput`` will break; confirm all call sites expect a
        plain list here.
        """
        # Get the encoder outputs
        _, skip_connections = self.encoder(x)
        return skip_connections

    def decode(self, z: torch.Tensor, skip_connections: list, return_dict: bool = True):
        """Decode latents ``z``, injecting encoder skip connections.

        ``z`` is assumed to already be a sampled latent; only
        ``post_quant_conv`` (when configured) is applied before decoding.
        """
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        # Decode the latent representation with skip connections
        dec = self.decoder(z, skip_connections)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ):
        # NOTE(review): this path looks broken and is presumably unused.
        # `self.encode` (above) returns a single list, so this 2-way
        # unpacking will misbehave, and no distribution with
        # `.sample()`/`.mode()` is ever constructed. Also, when
        # `return_dict=False`, `self.decode` already returns a tuple, so the
        # `(dec,)` below double-wraps it. Confirm whether `forward` is ever
        # called before relying on it.
        posterior, skip_connections = self.encode(sample, return_dict=True)
        # Sample from the posterior
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        # Decode the latent representation with skip connections
        dec = self.decode(z, skip_connections, return_dict=return_dict)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
class CustomDecoder(Decoder):
    """VAE decoder that adds encoder skip connections to its feature maps.

    Before each up block, the matching entry of ``skip_connections`` is
    added element-wise to the current feature map (ControlNet-style
    additive skips). The list is consumed in reverse: the encoder appends
    shallow-to-deep, while decoding proceeds deep-to-shallow.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        up_block_types: Tuple[str, ...],
        block_out_channels: Tuple[int, ...],
        layers_per_block: int,
        norm_num_groups: int,
        act_fn: str,
        mid_block_add_attention: bool,
    ):
        # Structure is identical to the stock diffusers Decoder; only
        # `forward` differs (it accepts skip connections).
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )

    def forward(
        self,
        sample: torch.Tensor,
        skip_connections: list,
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Decode ``sample``, adding one skip connection before each up block.

        Args:
            sample: Latent feature map (already passed through
                ``post_quant_conv`` by the caller).
            skip_connections: Per-down-block tensors from ``CustomEncoder``;
                entry ``-(i + 1)`` must be broadcast-addable to the input of
                up block ``i``.
            latent_embeds: Optional conditioning passed to the mid block and
                norm-out layer (usually ``None`` for a plain VAE).

        Returns:
            The decoded image-space tensor.
        """
        sample = self.conv_in(sample)
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        def _add_skip(hidden: torch.Tensor, block_idx: int) -> torch.Tensor:
            # Skips were appended shallowest-first; consume deepest-first.
            if block_idx < len(skip_connections):
                hidden = hidden + skip_connections[-(block_idx + 1)]
            return hidden

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # `use_reentrant=False` only exists on torch >= 1.11.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            # middle
            sample = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), sample, latent_embeds, **ckpt_kwargs
            )
            sample = sample.to(upscale_dtype)
            # up — fix: skip connections were previously applied only on the
            # non-checkpointed path, so training with gradient checkpointing
            # silently dropped them.
            for i, up_block in enumerate(self.up_blocks):
                sample = _add_skip(sample, i)
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), sample, latent_embeds, **ckpt_kwargs
                )
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)
            # up
            for i, up_block in enumerate(self.up_blocks):
                sample = _add_skip(sample, i)
                sample = up_block(sample)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample
class CustomEncoder(Encoder):
    r"""
    VAE encoder that taps each down block's output through a
    zero-initialized 1x1 convolution to produce decoder skip connections.

    Note: unlike the stock diffusers ``Encoder``, ``forward`` stops after
    the down blocks — the mid block, ``conv_norm_out``, ``conv_act`` and
    ``conv_out`` are never applied, so the returned ``sample`` is the raw
    output of the last down block, not latent-distribution moments.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention: bool = True,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            double_z=double_z,
            mid_block_add_attention=mid_block_add_attention,
        )
        # One zero conv per down block. The first two taps double their
        # channel count — presumably to match the widths the decoder's up
        # blocks expect at those depths; TODO confirm against CustomDecoder.
        # (Loop variable renamed from `out_channels`, which shadowed the
        # constructor parameter of the same name.)
        self.zero_convs = nn.ModuleList()
        for i, channels in enumerate(block_out_channels):
            if i < 2:
                self.zero_convs.append(ZeroConv2d(channels, channels * 2))
            else:
                self.zero_convs.append(ZeroConv2d(channels, channels))

    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, list]:
        r"""
        Forward pass of the CustomEncoder.

        Args:
            sample (`torch.Tensor`): Input image tensor.

        Returns:
            `Tuple[torch.Tensor, List[torch.Tensor]]`:
                - The output of the final down block (the mid block and
                  post-processing layers are intentionally skipped).
                - A list with one skip-connection tensor per down block,
                  ordered shallowest to deepest.
        """
        skip_connections = []
        # Initial convolution
        sample = self.conv_in(sample)
        # Down blocks
        for i, (down_block, zero_conv) in enumerate(zip(self.down_blocks, self.zero_convs)):
            sample = down_block(sample)
            if i != len(self.down_blocks) - 1:
                # Every down block except the last halves the spatial size,
                # so upsample the tap x2 to match the decoder feature map it
                # will be added to.
                sample_out = nn.functional.interpolate(
                    zero_conv(sample), scale_factor=2, mode='bilinear', align_corners=False
                )
            else:
                sample_out = zero_conv(sample)
            skip_connections.append(sample_out)
        return sample, skip_connections