# NOTE(review): removed non-Python extraction artifacts that preceded the module
# ("Spaces:" / "Runtime error" lines from the original paste).
"""transformer based denoiser""" | |
import torch | |
from einops.layers.torch import Rearrange | |
from torch import nn | |
from transformer_blocks import DecoderBlock, MLPSepConv, SinusoidalEmbedding | |
class DenoiserTransBlock(nn.Module):
    """Patchify an image, run non-causal transformer decoder blocks conditioned
    on an external embedding, and un-patchify back to image space.

    Args:
        patch_size: side length of each square patch.
        img_size: side length of the (square) input image; must be divisible
            by ``patch_size``.
        embed_dim: transformer embedding dimension.
        dropout: dropout level passed to each decoder block.
        n_layers: number of decoder blocks.
        mlp_multiplier: hidden-size multiplier for each block's MLP.
        n_channels: number of input image channels.

    Raises:
        ValueError: if ``img_size`` is not divisible by ``patch_size`` (the
            original code would silently truncate the patch grid and fail
            later inside the un-patchify Rearrange).
    """

    def __init__(
        self,
        patch_size: int,
        img_size: int,
        embed_dim: int,
        dropout: float,
        n_layers: int,
        mlp_multiplier: int = 4,
        n_channels: int = 4,
    ):
        super().__init__()
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.patch_size = patch_size
        self.img_size = img_size
        self.n_channels = n_channels
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.n_layers = n_layers
        self.mlp_multiplier = mlp_multiplier

        # Integer arithmetic (// instead of int(a / b)): exact for all sizes,
        # no float round-off.
        patches_per_side = self.img_size // self.patch_size
        seq_len = patches_per_side * patches_per_side
        patch_dim = self.n_channels * self.patch_size * self.patch_size

        # Conv2d with kernel == stride == patch_size extracts non-overlapping
        # patches; Rearrange flattens the spatial grid into a token sequence.
        self.patchify_and_embed = nn.Sequential(
            nn.Conv2d(
                self.n_channels,
                patch_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
            ),
            Rearrange("bs d h w -> bs (h w) d"),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
        )

        # Inverse of patchify: sequence of patch vectors -> full image.
        self.rearrange2 = Rearrange(
            "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
            h=patches_per_side,
            p1=self.patch_size,
            p2=self.patch_size,
        )

        # Learned absolute positional embedding over the patch sequence; the
        # index buffer moves with the module across devices.
        self.pos_embed = nn.Embedding(seq_len, self.embed_dim)
        self.register_buffer("precomputed_pos_enc", torch.arange(0, seq_len).long())

        self.decoder_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    embed_dim=self.embed_dim,
                    mlp_multiplier=self.mlp_multiplier,
                    # note that this is a non-causal block since we are
                    # denoising the entire image no need for masking
                    is_causal=False,
                    dropout_level=self.dropout,
                    mlp_class=MLPSepConv,
                )
                for _ in range(self.n_layers)
            ]
        )
        self.out_proj = nn.Sequential(nn.Linear(self.embed_dim, patch_dim), self.rearrange2)

    def forward(self, x, cond):
        """Denoise ``x`` (bs, c, h, w), conditioning every block on ``cond``.

        Returns a tensor with the same spatial shape as the input.
        """
        x = self.patchify_and_embed(x)
        pos_enc = self.precomputed_pos_enc[: x.size(1)].expand(x.size(0), -1)
        x = x + self.pos_embed(pos_enc)
        for block in self.decoder_blocks:
            x = block(x, cond)
        return self.out_proj(x)
class Denoiser(nn.Module):
    """Full denoiser: embeds the scalar noise level and the text conditioning,
    concatenates them into a 2-token conditioning sequence, and runs the
    transformer denoising trunk on the noisy image."""

    def __init__(
        self,
        image_size: int,
        noise_embed_dims: int,
        patch_size: int,
        embed_dim: int,
        dropout: float,
        n_layers: int,
        text_emb_size: int = 768,
    ):
        super().__init__()
        self.image_size = image_size
        self.noise_embed_dims = noise_embed_dims
        self.embed_dim = embed_dim

        # Sinusoidal features of the noise level, lifted into the
        # transformer's embedding space by a two-layer MLP.
        self.fourier_feats = nn.Sequential(
            SinusoidalEmbedding(embedding_dims=noise_embed_dims),
            nn.Linear(noise_embed_dims, self.embed_dim),
            nn.GELU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

        self.denoiser_trans_block = DenoiserTransBlock(patch_size, image_size, embed_dim, dropout, n_layers)
        self.norm = nn.LayerNorm(self.embed_dim)
        self.label_proj = nn.Linear(text_emb_size, self.embed_dim)

    def forward(self, x, noise_level, label):
        """Denoise ``x`` given a noise level and a text-embedding label."""
        noise_emb = self.fourier_feats(noise_level).unsqueeze(1)  # bs, 1, d
        label_emb = self.label_proj(label).unsqueeze(1)           # bs, 1, d
        # Normalized two-token conditioning sequence: bs, 2, d
        cond = self.norm(torch.cat([noise_emb, label_emb], dim=1))
        return self.denoiser_trans_block(x, cond)