|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Pytorch Unet Module used for diffusion. |
|
""" |
|
|
|
from dataclasses import dataclass |
|
import typing as tp |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from .transformer import StreamingTransformer, create_sin_embedding |
|
|
|
|
|
@dataclass |
|
class Output: |
|
sample: torch.Tensor |
|
|
|
|
|
def get_model(cfg, channels: int, side: int, num_steps: int): |
|
if cfg.model == 'unet': |
|
return DiffusionUnet( |
|
chin=channels, num_steps=num_steps, **cfg.diffusion_unet) |
|
else: |
|
raise RuntimeError('Not Implemented') |
|
|
|
|
|
class ResBlock(nn.Module): |
|
def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, |
|
dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, |
|
dropout: float = 0.): |
|
super().__init__() |
|
stride = 1 |
|
padding = dilation * (kernel - stride) // 2 |
|
Conv = nn.Conv1d |
|
Drop = nn.Dropout1d |
|
self.norm1 = nn.GroupNorm(norm_groups, channels) |
|
self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) |
|
self.activation1 = activation() |
|
self.dropout1 = Drop(dropout) |
|
|
|
self.norm2 = nn.GroupNorm(norm_groups, channels) |
|
self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) |
|
self.activation2 = activation() |
|
self.dropout2 = Drop(dropout) |
|
|
|
def forward(self, x): |
|
h = self.dropout1(self.conv1(self.activation1(self.norm1(x)))) |
|
h = self.dropout2(self.conv2(self.activation2(self.norm2(h)))) |
|
return x + h |
|
|
|
|
|
class DecoderLayer(nn.Module): |
|
def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, |
|
norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, |
|
dropout: float = 0.): |
|
super().__init__() |
|
padding = (kernel - stride) // 2 |
|
self.res_blocks = nn.Sequential( |
|
*[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) |
|
for idx in range(res_blocks)]) |
|
self.norm = nn.GroupNorm(norm_groups, chin) |
|
ConvTr = nn.ConvTranspose1d |
|
self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False) |
|
self.activation = activation() |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.res_blocks(x) |
|
x = self.norm(x) |
|
x = self.activation(x) |
|
x = self.convtr(x) |
|
return x |
|
|
|
|
|
class EncoderLayer(nn.Module): |
|
def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, |
|
norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, |
|
dropout: float = 0.): |
|
super().__init__() |
|
padding = (kernel - stride) // 2 |
|
Conv = nn.Conv1d |
|
self.conv = Conv(chin, chout, kernel, stride, padding, bias=False) |
|
self.norm = nn.GroupNorm(norm_groups, chout) |
|
self.activation = activation() |
|
self.res_blocks = nn.Sequential( |
|
*[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) |
|
for idx in range(res_blocks)]) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
B, C, T = x.shape |
|
stride, = self.conv.stride |
|
pad = (stride - (T % stride)) % stride |
|
x = F.pad(x, (0, pad)) |
|
|
|
x = self.conv(x) |
|
x = self.norm(x) |
|
x = self.activation(x) |
|
x = self.res_blocks(x) |
|
return x |
|
|
|
|
|
class BLSTM(nn.Module): |
|
"""BiLSTM with same hidden units as input dim. |
|
""" |
|
def __init__(self, dim, layers=2): |
|
super().__init__() |
|
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) |
|
self.linear = nn.Linear(2 * dim, dim) |
|
|
|
def forward(self, x): |
|
x = x.permute(2, 0, 1) |
|
x = self.lstm(x)[0] |
|
x = self.linear(x) |
|
x = x.permute(1, 2, 0) |
|
return x |
|
|
|
|
|
class DiffusionUnet(nn.Module): |
|
def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2., |
|
max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False, |
|
bilstm: bool = False, transformer: bool = False, |
|
codec_dim: tp.Optional[int] = None, **kwargs): |
|
super().__init__() |
|
self.encoders = nn.ModuleList() |
|
self.decoders = nn.ModuleList() |
|
self.embeddings: tp.Optional[nn.ModuleList] = None |
|
self.embedding = nn.Embedding(num_steps, hidden) |
|
if emb_all_layers: |
|
self.embeddings = nn.ModuleList() |
|
self.condition_embedding: tp.Optional[nn.Module] = None |
|
for d in range(depth): |
|
encoder = EncoderLayer(chin, hidden, **kwargs) |
|
decoder = DecoderLayer(hidden, chin, **kwargs) |
|
self.encoders.append(encoder) |
|
self.decoders.insert(0, decoder) |
|
if emb_all_layers and d > 0: |
|
assert self.embeddings is not None |
|
self.embeddings.append(nn.Embedding(num_steps, hidden)) |
|
chin = hidden |
|
hidden = min(int(chin * growth), max_channels) |
|
self.bilstm: tp.Optional[nn.Module] |
|
if bilstm: |
|
self.bilstm = BLSTM(chin) |
|
else: |
|
self.bilstm = None |
|
self.use_transformer = transformer |
|
self.cross_attention = False |
|
if transformer: |
|
self.cross_attention = cross_attention |
|
self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False, |
|
cross_attention=cross_attention) |
|
|
|
self.use_codec = False |
|
if codec_dim is not None: |
|
self.conv_codec = nn.Conv1d(codec_dim, chin, 1) |
|
self.use_codec = True |
|
|
|
def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None): |
|
skips = [] |
|
bs = x.size(0) |
|
z = x |
|
view_args = [1] |
|
if type(step) is torch.Tensor: |
|
step_tensor = step |
|
else: |
|
step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs) |
|
|
|
for idx, encoder in enumerate(self.encoders): |
|
z = encoder(z) |
|
if idx == 0: |
|
z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z) |
|
elif self.embeddings is not None: |
|
z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z) |
|
|
|
skips.append(z) |
|
|
|
if self.use_codec: |
|
assert condition is not None, "Model defined for conditionnal generation" |
|
condition_emb = self.conv_codec(condition) |
|
assert condition_emb.size(-1) <= 2 * z.size(-1), \ |
|
f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}" |
|
if not self.cross_attention: |
|
|
|
condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1)) |
|
assert z.size() == condition_emb.size() |
|
z += condition_emb |
|
cross_attention_src = None |
|
else: |
|
cross_attention_src = condition_emb.permute(0, 2, 1) |
|
B, T, C = cross_attention_src.shape |
|
positions = torch.arange(T, device=x.device).view(1, -1, 1) |
|
pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype) |
|
cross_attention_src = cross_attention_src + pos_emb |
|
if self.use_transformer: |
|
z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1) |
|
else: |
|
if self.bilstm is None: |
|
z = torch.zeros_like(z) |
|
else: |
|
z = self.bilstm(z) |
|
|
|
for decoder in self.decoders: |
|
s = skips.pop(-1) |
|
z = z[:, :, :s.shape[2]] |
|
z = z + s |
|
z = decoder(z) |
|
|
|
z = z[:, :, :x.shape[2]] |
|
return Output(z) |
|
|