import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple

import torch
from torch import nn
from torch.nn import init

from .DiffAE_support_choices import *
from .DiffAE_support_config_base import BaseConfig
from .DiffAE_model_blocks import *
from .DiffAE_model_nn import timestep_embedding
from .DiffAE_model_unet import *


class LatentNetType(Enum):
    none = 'none'
    # concatenates the input to selected hidden layers (MLPSkipNet)
    skip = 'skip'


class LatentNetReturn(NamedTuple):
    pred: torch.Tensor = None


@dataclass
class MLPSkipNetConfig(BaseConfig):
    """
    Default MLP configuration for the latent DPM used in the DiffAE paper.
    """
    num_channels: int
    skip_layers: Tuple[int, ...]
    num_hid_channels: int
    num_layers: int
    num_time_emb_channels: int = 64
    activation: Activation = Activation.silu
    use_norm: bool = True
    condition_bias: float = 1
    dropout: float = 0
    last_act: Activation = Activation.none
    num_time_layers: int = 2
    time_last_act: bool = False

    def make_model(self):
        return MLPSkipNet(self)
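    # A minimal usage sketch (illustrative values only; the real hyperparameters
    # come from the experiment configs elsewhere in this package):
    #
    #   conf = MLPSkipNetConfig(
    #       num_channels=512,
    #       skip_layers=(1, 2, 3),
    #       num_hid_channels=1024,
    #       num_layers=10,
    #   )
    #   model = conf.make_model()  # equivalent to MLPSkipNet(conf)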


class MLPSkipNet(nn.Module):
    """
    MLP with skip connections: the input is concatenated to selected hidden
    layers. This is the default latent DPM network in the DiffAE paper.
    """
    def __init__(self, conf: MLPSkipNetConfig):
        super().__init__()
        self.conf = conf

        # MLP that maps the sinusoidal timestep embedding to the condition
        # vector fed to every conditioned layer.
        layers = []
        for i in range(conf.num_time_layers):
            if i == 0:
                a = conf.num_time_emb_channels
                b = conf.num_channels
            else:
                a = conf.num_channels
                b = conf.num_channels
            layers.append(nn.Linear(a, b))
            if i < conf.num_time_layers - 1 or conf.time_last_act:
                layers.append(conf.activation.get_act())
        self.time_embed = nn.Sequential(*layers)

        self.layers = nn.ModuleList([])
        for i in range(conf.num_layers):
            if i == 0:
                # input layer: num_channels -> num_hid_channels
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_channels, conf.num_hid_channels
                dropout = conf.dropout
            elif i == conf.num_layers - 1:
                # output layer: no activation, norm, condition, or dropout
                act = Activation.none
                norm = False
                cond = False
                a, b = conf.num_hid_channels, conf.num_channels
                dropout = 0
            else:
                # hidden layer: num_hid_channels -> num_hid_channels
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_hid_channels, conf.num_hid_channels
                dropout = conf.dropout

            if i in conf.skip_layers:
                # the input x is concatenated to this layer's input
                a += conf.num_channels

            self.layers.append(
                MLPLNAct(
                    a,
                    b,
                    norm=norm,
                    activation=act,
                    cond_channels=conf.num_channels,
                    use_cond=cond,
                    condition_bias=conf.condition_bias,
                    dropout=dropout,
                ))
        self.last_act = conf.last_act.get_act()

    def forward(self, x, t, **kwargs):
        t = timestep_embedding(t, self.conf.num_time_emb_channels)
        cond = self.time_embed(t)
        h = x
        for i in range(len(self.layers)):
            if i in self.conf.skip_layers:
                # inject the input into the skip layers
                h = torch.cat([h, x], dim=1)
            h = self.layers[i](x=h, cond=cond)
        h = self.last_act(h)
        return LatentNetReturn(h)
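    # Shape sketch (hypothetical sizes, assuming num_channels=512 and
    # num_time_emb_channels=64):
    #   x: (batch, 512), t: (batch,)
    #   timestep_embedding(t, 64) -> (batch, 64) --time_embed--> cond: (batch, 512)
    #   hidden activations: (batch, num_hid_channels), widened by 512 at skip layers
    #   returned pred: (batch, 512)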


class MLPLNAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm: bool,
        use_cond: bool,
        activation: Activation,
        cond_channels: int,
        condition_bias: float = 0,
        dropout: float = 0,
    ):
        super().__init__()
        self.activation = activation
        self.condition_bias = condition_bias
        self.use_cond = use_cond

        self.linear = nn.Linear(in_channels, out_channels)
        self.act = activation.get_act()
        if self.use_cond:
            self.linear_emb = nn.Linear(cond_channels, out_channels)
            self.cond_layers = nn.Sequential(self.act, self.linear_emb)
        if norm:
            self.norm = nn.LayerNorm(out_channels)
        else:
            self.norm = nn.Identity()

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

        self.init_weights()

    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if self.activation == Activation.relu:
                    init.kaiming_normal_(module.weight, a=0, nonlinearity='relu')
                elif self.activation == Activation.lrelu:
                    init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                elif self.activation == Activation.silu:
                    # kaiming_normal_ has no gain for silu; relu is the closest match
                    init.kaiming_normal_(module.weight, a=0, nonlinearity='relu')
                else:
                    # keep PyTorch's default initialization
                    pass

    def forward(self, x, cond=None):
        x = self.linear(x)
        if self.use_cond:
            cond = self.cond_layers(cond)
            # (scale, shift) conditioning; only the scale part is used here
            cond = (cond, None)

            # scale and shift first ...
            x = x * (self.condition_bias + cond[0])
            if cond[1] is not None:
                x = x + cond[1]
            # ... then normalize
            x = self.norm(x)
        else:
            # no condition
            x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return x
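
# A hedged smoke test for the whole module (illustrative sizes; run from the
# package root so the relative imports resolve):
#
#   conf = MLPSkipNetConfig(num_channels=512, skip_layers=(1,),
#                           num_hid_channels=1024, num_layers=3)
#   net = conf.make_model()
#   x = torch.randn(4, 512)            # latent codes
#   t = torch.randint(0, 1000, (4,))   # diffusion timesteps
#   out = net(x, t)
#   assert out.pred.shape == (4, 512)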