import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import SelfAttention, CrossAttention


class TimeEmbedding(nn.Module):
    """Two-layer MLP that widens a time embedding by a factor of four.

    ``linear_1`` maps ``n_embd -> 4 * n_embd`` and ``linear_2`` maps
    ``4 * n_embd -> 4 * n_embd``, with a SiLU in between. With ``n_embd=320``
    (as used by ``Diffusion``) the output has 1280 features, which matches the
    default ``n_time`` of ``UNET_ResidualBlock``.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        self.linear_1 = nn.Linear(n_embd, hidden)
        self.linear_2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        hidden = self.linear_1(x)
        hidden = F.silu(hidden)
        return self.linear_2(hidden)


class UNET_ResidualBlock(nn.Module):
    """Residual conv block that injects a time embedding between its two convs.

    Path: GroupNorm -> SiLU -> 3x3 conv, add a projected time embedding,
    then GroupNorm -> SiLU -> 3x3 conv, and finally add the (possibly
    1x1-projected) input.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count of the returned feature map.
        n_time: width of the ``time`` tensor passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.linear_time = nn.Linear(n_time, out_channels)

        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # The skip path needs a 1x1 projection only when the channel count changes.
        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, feature, time):
        h = self.groupnorm_feature(feature)
        h = F.silu(h)
        h = self.conv_feature(h)

        # Append two trailing singleton dims so the per-channel time vector
        # broadcasts over the spatial dims of ``h``.
        t = self.linear_time(F.silu(time))
        h = h + t[..., None, None]

        h = self.groupnorm_merged(h)
        h = F.silu(h)
        h = self.conv_merged(h)

        return h + self.residual_layer(feature)


class UNET_AttentionBlock(nn.Module):
    """Spatial transformer block: self-attention, cross-attention and a GeGLU MLP.

    The feature map is normalised, 1x1-projected and flattened into a
    sequence of ``h * w`` tokens of width ``channels = n_head * n_embd``.
    Three pre-LayerNorm sub-layers follow, each with its own residual add:
    self-attention, cross-attention against ``context``, and a GeGLU
    feed-forward. The tokens are then folded back to ``(n, c, h, w)``,
    1x1-projected, and added to the block input.

    Args:
        n_head: number of attention heads (passed to the attention modules).
        n_embd: per-head width; ``channels = n_head * n_embd``.
        d_context: width of the ``context`` tensor given to cross-attention.
    """

    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd

        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        # A single projection yields both the value half and the gate half of GeGLU.
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        residue_long = x

        x = self.groupnorm(x)
        x = self.conv_input(x)

        n, c, h, w = x.shape
        # (n, c, h, w) -> (n, h*w, c): one token per spatial position.
        tokens = x.view((n, c, h * w)).transpose(-1, -2)

        tokens = self.attention_1(self.layernorm_1(tokens)) + tokens
        tokens = self.attention_2(self.layernorm_2(tokens), context) + tokens

        projected = self.linear_geglu_1(self.layernorm_3(tokens))
        value, gate = projected.chunk(2, dim=-1)
        tokens = self.linear_geglu_2(value * F.gelu(gate)) + tokens

        # (n, h*w, c) -> (n, c, h, w). The transpose only reorders strides and
        # the view merely splits the last dim, so no copy is needed.
        x = tokens.transpose(-1, -2).view((n, c, h, w))
        return self.conv_output(x) + residue_long


class Upsample(nn.Module):
    """Double the spatial size (nearest neighbour), then apply a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        upscaled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upscaled)


class SwitchSequential(nn.Sequential):
    """``nn.Sequential`` that routes the extra inputs each child layer needs.

    Attention blocks receive ``context``, residual blocks receive ``time``,
    and any other layer is called with the activation alone.
    """

    def forward(self, x, context, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                extra_args = (context,)
            elif isinstance(layer, UNET_ResidualBlock):
                extra_args = (time,)
            else:
                extra_args = ()
            x = layer(x, *extra_args)
        return x


class UNET(nn.Module):
    """Encoder / bottleneck / decoder U-Net over 4-channel inputs.

    ``forward`` runs every encoder stage and stores each stage's output. It
    then runs the bottleneck, then every decoder stage. Before each decoder
    stage, the current activation is concatenated along dim 1 with the most
    recently stored encoder output. That is why decoder input widths are
    sums such as 2560 = 1280 + 1280 or 960 = 640 + 320.

    The three stride-2 encoder convs and the three ``Upsample`` stages must
    produce matching spatial sizes for those concatenations. In practice this
    means H and W of the input must be multiples of 8.
    """

    def __init__(self):
        super().__init__()
        # Short local aliases keep the stage tables readable.
        res = UNET_ResidualBlock
        attn = UNET_AttentionBlock
        stage = SwitchSequential

        def downsample(channels):
            # 3x3 stride-2 conv: halves H and W (rounding up), keeps channels.
            return nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

        self.encoders = nn.ModuleList([
            stage(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
            stage(res(320, 320), attn(8, 40)),
            stage(res(320, 320), attn(8, 40)),
            stage(downsample(320)),
            stage(res(320, 640), attn(8, 80)),
            stage(res(640, 640), attn(8, 80)),
            stage(downsample(640)),
            stage(res(640, 1280), attn(8, 160)),
            stage(res(1280, 1280), attn(8, 160)),
            stage(downsample(1280)),
            stage(res(1280, 1280)),
            stage(res(1280, 1280)),
        ])

        self.bottleneck = stage(
            res(1280, 1280),
            attn(8, 160),
            res(1280, 1280),
        )

        # Each decoder's first in_channels = current width + popped skip width.
        self.decoders = nn.ModuleList([
            stage(res(2560, 1280)),
            stage(res(2560, 1280)),
            stage(res(2560, 1280), Upsample(1280)),
            stage(res(2560, 1280), attn(8, 160)),
            stage(res(2560, 1280), attn(8, 160)),
            stage(res(1920, 1280), attn(8, 160), Upsample(1280)),
            stage(res(1920, 640), attn(8, 80)),
            stage(res(1280, 640), attn(8, 80)),
            stage(res(960, 640), attn(8, 80), Upsample(640)),
            stage(res(960, 320), attn(8, 40)),
            stage(res(640, 320), attn(8, 40)),
            stage(res(640, 320), attn(8, 40)),
        ])

    def forward(self, x, context, time):
        skips = []
        for encoder in self.encoders:
            x = encoder(x, context, time)
            skips.append(x)

        x = self.bottleneck(x, context, time)

        for decoder in self.decoders:
            # Skips are consumed last-in, first-out to mirror the encoder path.
            merged = torch.cat((x, skips.pop()), dim=1)
            x = decoder(merged, context, time)
        return x


class UNET_OutputLayer(nn.Module):
    """GroupNorm -> SiLU -> 3x3 conv projecting to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        activated = F.silu(self.groupnorm(x))
        return self.conv(activated)


class Diffusion(nn.Module):
    """Time-conditioned U-Net that maps a 4-channel latent to a 4-channel output.

    ``time`` is widened by ``TimeEmbedding(320)`` to 1280 features before it
    reaches the U-Net. ``context`` is forwarded to every cross-attention
    layer; its width must equal ``d_context`` (768 by default).

    NOTE(review): ``time`` presumably has shape (batch, 320) — confirm with
    the caller.
    """

    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        time_features = self.time_embedding(time)
        features = self.unet(latent, context, time_features)
        return self.final(features)