# flux-latent-preview/tea_model.py
import torch
from torch import nn
# VAE decoder with half-size output: the final upsample stage is omitted.
###
# Adapted from madebyollin/taesd.
class Recon(nn.Module):
def __init__(self, ch_in, ch_out):
super().__init__()
self.long = nn.Sequential(
nn.Conv2d(ch_in, ch_out, 3, padding=1),
nn.ReLU(),
nn.Conv2d(ch_out, ch_out, 3, padding=1),
nn.ReLU(),
nn.Conv2d(ch_out, ch_out, 3, padding=1)
)
if ch_in != ch_out:
self.short = nn.Conv2d(ch_in, ch_out, 1, bias=False)
else:
            # Channel counts match, so the skip path is a plain pass-through.
self.short = nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.long(x) + self.short(x))
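
# Note: Recon preserves spatial dimensions; e.g. Recon(16, 64) maps a
# (B, 16, H, W) tensor to (B, 64, H, W), with the 1x1 shortcut matching
# channels so the residual sum is well-defined.
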
class TeaDecoder(nn.Module):
def __init__(self, ch_in):
super().__init__()
self.block_in = nn.Sequential(
nn.Conv2d(ch_in, 64, 3, padding=1),
nn.ReLU()
)
self.middle = nn.Sequential(
*[Recon(64, 64) for _ in range(3)],
            # Inverse of a stride-2 downsample: doubles the spatial resolution.
nn.Upsample(scale_factor=2),
            # Dropping the bias gives a slightly smaller model; the preceding
            # feature maps already carry the needed offsets, so a bias here
            # adds little value.
nn.Conv2d(64, 64, 3, padding=1, bias=False),
            *[Recon(64, 64) for _ in range(3)],
            # Final 2x upsample, reaching 1/2 of the full image size.
            nn.Upsample(scale_factor=2),
nn.Conv2d(64, 64, 3, padding=1, bias=False),
)
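        # Net effect of self.middle: two 2x upsamples, i.e. 4x spatial growth.
        # Assuming 1/8-scale latents (as in Flux), the decoded preview lands
        # at 1/2 of the original image size.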
self.block_out = nn.Sequential(
Recon(64, 64),
            # Project to RGB; the output is 3 channels regardless of the latent channel count.
nn.Conv2d(64, 3, 3, padding=1),
)
def forward(self, x):
        # Soft-clamp the latents into (-1, 1); tanh avoids a hard cutoff.
        clamped = torch.tanh(x)
cooked = self.middle(self.block_in(clamped))
return self.block_out(cooked)
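
# A minimal smoke test (a sketch, not part of the upstream file). It assumes
# Flux-style latents with 16 channels; pass a different ch_in if your latent
# space differs.
if __name__ == '__main__':
    decoder = TeaDecoder(ch_in=16)
    # A 256x256 image encodes to a 32x32 latent at 1/8 scale.
    latent = torch.randn(1, 16, 32, 32)
    with torch.no_grad():
        preview = decoder(latent)
    print(preview.shape)  # torch.Size([1, 3, 128, 128]), half of 256x256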