|
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Optional, Tuple

import torch
from torch import Tensor
from torch.nn.functional import silu

from .DiffAE_model_latentnet import *
from .DiffAE_model_unet import *
from .DiffAE_support_choices import *
|
|
|
|
|
@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    """Configuration for ``BeatGANsAutoencModel``.

    Extends ``BeatGANsUNetConfig`` with settings for the semantic image
    encoder and an optional latent network.
    """

    # Width of the encoder output (the conditioning vector); also the size of
    # the vectors drawn by BeatGANsAutoencModel.sample_z.
    enc_out_channels: int = 512
    # Encoder attention resolutions; None (or empty) falls back to
    # ``attention_resolutions``.
    enc_attn_resolutions: Optional[Tuple[int, ...]] = None
    # Pooling mode passed to the encoder as ``pool``.
    enc_pool: str = 'depthconv'
    # Residual blocks per encoder level (passed as ``num_res_blocks``).
    enc_num_res_block: int = 2
    # Encoder channel multipliers; None (or empty) falls back to
    # ``channel_mult``.
    enc_channel_mult: Optional[Tuple[int, ...]] = None
    # Gradient checkpointing for the encoder; OR-ed with ``use_checkpoint``.
    enc_grad_checkpoint: bool = False
    # Config for BeatGANsAutoencModel.latent_net; no latent net is built
    # when None.
    latent_net_conf: Optional[MLPSkipNetConfig] = None

    def make_model(self):
        """Build a ``BeatGANsAutoencModel`` from this configuration."""
        return BeatGANsAutoencModel(self)
|
|
|
|
|
class BeatGANsAutoencModel(BeatGANsUNetModel):
    """UNet denoiser conditioned on a vector produced by an image encoder.

    On top of ``BeatGANsUNetModel`` this adds:

    * ``self.encoder``: a model built from ``BeatGANsEncoderConfig`` with
      ``use_time_condition=False``. It maps an image to the condition
      vector ``cond``.
    * ``self.time_embed``: a ``TimeStyleSeperateEmbed`` that embeds the
      timestep and passes ``cond`` through as the ResBlock condition.
    * ``self.latent_net``: created only when ``conf.latent_net_conf`` is
      not None.
    """

    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # Assigned after super().__init__ so it takes precedence over any
        # ``time_embed`` the base class may have created.
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        # Semantic encoder. Encoder-specific settings fall back to the UNet's
        # own settings when unset.
        self.encoder = BeatGANsEncoderConfig(
            image_size=conf.image_size,
            in_channels=conf.in_channels,
            model_channels=conf.model_channels,
            out_hid_channels=conf.enc_out_channels,
            out_channels=conf.enc_out_channels,
            num_res_blocks=conf.enc_num_res_block,
            attention_resolutions=(conf.enc_attn_resolutions
                                   or conf.attention_resolutions),
            dropout=conf.dropout,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
            use_time_condition=False,
            conv_resample=conf.conv_resample,
            group_norm_limit=conf.group_norm_limit,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            resblock_updown=conf.resblock_updown,
            use_new_attention_order=conf.use_new_attention_order,
            pool=conf.enc_pool,
        ).make_model()

        if conf.latent_net_conf is not None:
            self.latent_net = conf.latent_net_conf.make_model()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Reparameterization trick: the sample stays differentiable with
        respect to ``mu`` and ``logvar``.

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        assert self.conf.is_stochastic
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample_z(self, n: int, device):
        """Draw ``n`` standard-normal vectors of size ``conf.enc_out_channels``."""
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        """Map noise to a condition vector (not supported).

        :raises NotImplementedError: always; this model builds no
            noise-to-condition network.
        """
        raise NotImplementedError(
            'noise_to_cond is not supported by BeatGANsAutoencModel')

    def encode(self, x):
        """Run the semantic encoder on ``x`` and return ``{'cond': ...}``."""
        cond = self.encoder.forward(x)
        return {'cond': cond}

    def _iter_resblocks(self):
        """Yield every ResBlock in input, middle and output blocks, in order."""
        for part in (self.input_blocks, self.middle_block,
                     self.output_blocks):
            for module in part.modules():
                if isinstance(module, ResBlock):
                    yield module

    @property
    def stylespace_sizes(self):
        """Output width of the last ``cond_emb_layers`` entry of each ResBlock."""
        return [
            block.cond_emb_layers[-1].weight.shape[0]
            for block in self._iter_resblocks()
        ]

    def encode_stylespace(self, x, return_vector: bool = True):
        """Encode ``x`` and project the condition through every ResBlock.

        :param x: input passed to ``self.encoder``.
        :param return_vector: when True, concatenate the per-block outputs
            along dim 1; otherwise return them as a list.
        """
        cond = self.encoder.forward(x)
        styles = [
            block.cond_emb_layers.forward(cond)
            for block in self._iter_resblocks()
        ]
        if return_vector:
            return torch.cat(styles, dim=1)
        return styles

    def forward(self,
                x,
                t,
                y=None,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """Apply the model to an input batch.

        Args:
            x: noisy input to the UNet. If None, the encoder path is skipped
                and the decoder starts from ``h = None`` with no skip
                connections.
            t: timesteps for ``x``. If None, no time embedding is computed.
            y: class labels. Must be None, because class conditioning is not
                implemented.
            x_start: the original image to encode. Required when ``cond``
                is not given.
            cond: output of the encoder. Computed from ``x_start`` if None.
            style: accepted for interface compatibility; currently unused.
            noise: random noise (to predict the cond). ``noise_to_cond``
                raises in this class.
            t_cond: timesteps for the condition; defaults to ``t``.
            **kwargs: ignored.

        Returns:
            ``AutoencReturn(pred, cond)``. ``pred`` is ``self.out`` applied to
            the final decoder features; ``cond`` is the condition used.

        Raises:
            ValueError: if ``cond`` must be computed but ``x_start`` is None.
            NotImplementedError: if ``noise`` is given, if
                ``conf.resnet_two_cond`` is false, or if ``conf.num_classes``
                is set.
        """
        if t_cond is None:
            t_cond = t

        if noise is not None:
            cond = self.noise_to_cond(noise)

        if cond is None:
            if x_start is None:
                raise ValueError('x_start is required when cond is not given')
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
            cond = self.encode(x_start)['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            _t_emb = None
            _t_cond_emb = None

        if not self.conf.resnet_two_cond:
            raise NotImplementedError()

        res = self.time_embed.forward(
            time_emb=_t_emb,
            cond=cond,
            time_cond_emb=_t_cond_emb,
        )
        # Two-condition ResBlocks receive the time embedding and the
        # condition embedding as separate inputs.
        emb = res.time_emb
        cond_emb = res.emb

        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        # Skip connections collected per channel_mult level on the way down.
        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)

            k = 0
            for level, n_blocks in enumerate(self.input_num_blocks):
                for _ in range(n_blocks):
                    h = self.input_blocks[k](h, emb=emb, cond=cond_emb)
                    hs[level].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            h = self.middle_block(h, emb=emb, cond=cond_emb)
        else:
            # No input: the decoder runs without any skip connections.
            h = None

        k = 0
        for level, n_blocks in enumerate(self.output_num_blocks):
            for _ in range(n_blocks):
                # Levels are consumed in reverse order. lateral is None once a
                # level is exhausted (or when no skips were collected).
                try:
                    lateral = hs[-level - 1].pop()
                except IndexError:
                    lateral = None

                h = self.output_blocks[k](h,
                                          emb=emb,
                                          cond=cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)
        return AutoencReturn(pred=pred, cond=cond)
|
|
|
|
|
class AutoencReturn(NamedTuple):
    """Result of ``BeatGANsAutoencModel.forward``."""

    # Output of the model's final ``out`` layer.
    pred: Tensor
    # Condition used for the forward pass; None when not provided.
    cond: Optional[Tensor] = None
|
|
|
|
|
class EmbedReturn(NamedTuple):
    """Output of ``TimeStyleSeperateEmbed.forward``; every field defaults to None."""

    # Condition embedding (used as ``cond_emb`` by BeatGANsAutoencModel.forward).
    emb: Optional[Tensor] = None
    # Embedded timestep, or None when no timestep embedding was supplied.
    time_emb: Optional[Tensor] = None
    # Style vector.
    style: Optional[Tensor] = None
|
|
|
|
|
class TimeStyleSeperateEmbed(nn.Module):
    """Embed timesteps with a small MLP; pass the condition through as-is.

    ``forward`` returns an ``EmbedReturn`` where ``emb`` and ``style`` are
    both the identity-mapped condition and ``time_emb`` is the MLP output
    (None when no time embedding is given).
    """

    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        # MLP: time_channels -> time_out_channels -> time_out_channels.
        mlp_layers = [
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        ]
        self.time_embed = nn.Sequential(*mlp_layers)
        # The condition is forwarded unchanged.
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        """Embed ``time_emb`` (if given) and pass ``cond`` through.

        Extra keyword arguments (e.g. ``time_cond_emb``) are accepted and
        ignored.
        """
        embedded_time = (
            self.time_embed(time_emb) if time_emb is not None else None
        )
        passthrough = self.style(cond)
        return EmbedReturn(emb=passthrough,
                           time_emb=embedded_time,
                           style=passthrough)
|
|