import math

import torch
from torch import nn, einsum
from einops import rearrange, repeat

from .utils import exist, set_default_layer
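
# `exist` and `set_default_layer` live in the repo's local utils module and are
# not shown here. A minimal sketch of what they are assumed to do, inferred
# purely from how they are called below (hypothetical reconstruction, not the
# repo's source):
#
#     def exist(item):
#         return item is not None
#
#     def set_default_layer(condition, layer_1, args_1=(), kwargs_1=None,
#                           layer_2=nn.Identity, args_2=(), kwargs_2=None):
#         # Instantiate layer_1 when the condition holds, else layer_2.
#         if condition:
#             return layer_1(*args_1, **(kwargs_1 or {}))
#         return layer_2(*args_2, **(kwargs_2 or {}))
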
class Identity(nn.Module):
    """Pass-through module that accepts and ignores any extra arguments."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

class SinusoidalPosEmb_sr(nn.Module):
    """Sinusoidal position embedding; the frequency table is cast to the input dtype."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j').to(dtype=x.dtype)
        return torch.cat((emb.sin(), emb.cos()), dim=-1)

class UpDownResolution(nn.Module):
    """Doubles or halves spatial resolution, either with a (transposed) convolution
    or with parameter-free nearest upsampling / average pooling."""

    def __init__(self, num_channels, up_resolution, change_type='conv'):
        super().__init__()
        if change_type == 'pooling':
            self.change_resolution = set_default_layer(
                up_resolution,
                layer_1=nn.Upsample, kwargs_1={'scale_factor': 2., 'mode': 'nearest'},
                layer_2=nn.AvgPool2d, kwargs_2={'kernel_size': 2, 'stride': 2}
            )
        elif change_type == 'conv':
            self.change_resolution = set_default_layer(
                up_resolution,
                nn.ConvTranspose2d, (num_channels, num_channels), {'kernel_size': 4, 'stride': 2, 'padding': 1},
                nn.Conv2d, (num_channels, num_channels), {'kernel_size': 4, 'stride': 2, 'padding': 1},
            )
        else:
            raise NotImplementedError(f'Unknown change_type: {change_type}')

    def forward(self, x):
        return self.change_resolution(x)

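# Quick shape check for UpDownResolution (illustrative sketch with an assumed
# 4-channel feature map; not part of the original module). Both conv variants
# use kernel 4 / stride 2 / padding 1, which exactly doubles or halves H and W:
#
#     block = UpDownResolution(num_channels=4, up_resolution=True, change_type='conv')
#     x = torch.randn(1, 4, 16, 16)
#     block(x).shape  # torch.Size([1, 4, 32, 32]); with up_resolution=False: [1, 4, 8, 8]
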
class SinusoidalPosEmb(nn.Module):
    """Standard sinusoidal position embedding, computed entirely in the input dtype."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim=-1)

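# Illustrative sketch of the embedding shape (assumed usage, e.g. timestep
# conditioning; not part of the original file): a 1-D batch of scalar
# positions maps to `dim`-dimensional sin/cos features.
#
#     pos_emb = SinusoidalPosEmb(dim=128)
#     t = torch.arange(8, dtype=torch.float32)  # 8 timesteps
#     pos_emb(t).shape                          # torch.Size([8, 128])
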
class ConditionalGroupNorm(nn.Module):
    """GroupNorm whose scale and shift are predicted from a context vector.
    The projection is zero-initialized, so modulation starts as the identity."""

    def __init__(self, groups, normalized_shape, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        self.context_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(context_dim, 2 * normalized_shape)
        )
        self.context_mlp[1].weight.data.zero_()
        self.context_mlp[1].bias.data.zero_()

    def forward(self, x, context):
        context = self.context_mlp(context)
        # Broadcast the (b, 2c) context over the spatial dims of x.
        ndims = ' 1' * len(x.shape[2:])
        context = rearrange(context, f'b c -> b c{ndims}')
        scale, shift = context.chunk(2, dim=1)
        x = self.norm(x) * (scale + 1.) + shift
        return x

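# Minimal sketch of the conditioning path (assumed shapes, not from the
# original file). Because the MLP is zero-initialized, scale and shift are
# zero at the start of training and the layer behaves like plain GroupNorm:
#
#     norm = ConditionalGroupNorm(groups=8, normalized_shape=64, context_dim=256)
#     x = torch.randn(2, 64, 16, 16)  # image features
#     context = torch.randn(2, 256)   # e.g. a pooled text/time embedding
#     norm(x, context).shape          # torch.Size([2, 64, 16, 16])
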
class Attention(nn.Module):
    """Multi-head cross-attention: queries come from x, keys/values from context."""

    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
        super().__init__()
        assert out_channels % head_dim == 0
        self.num_heads = out_channels // head_dim
        self.scale = head_dim ** -0.5

        self.to_query = nn.Linear(in_channels, out_channels, bias=False)
        self.to_key = nn.Linear(context_dim, out_channels, bias=False)
        self.to_value = nn.Linear(context_dim, out_channels, bias=False)
        self.output_layer = nn.Linear(out_channels, out_channels, bias=False)

    def forward(self, x, context, context_mask=None):
        query = rearrange(self.to_query(x), 'b n (h d) -> b h n d', h=self.num_heads)
        key = rearrange(self.to_key(context), 'b n (h d) -> b h n d', h=self.num_heads)
        value = rearrange(self.to_value(context), 'b n (h d) -> b h n d', h=self.num_heads)

        attention_matrix = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale
        if exist(context_mask):
            # Mask out padded context tokens before the softmax.
            max_neg_value = -torch.finfo(attention_matrix.dtype).max
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            attention_matrix = attention_matrix.masked_fill(~context_mask, max_neg_value)
        attention_matrix = attention_matrix.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attention_matrix, value)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.output_layer(out)
        return out
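
# Cross-attention usage sketch (assumed shapes, not from the original file):
# flattened image tokens attend over context tokens, with a boolean mask in
# which True marks valid tokens and False marks padding to be ignored.
#
#     attn = Attention(in_channels=64, out_channels=128, context_dim=256, head_dim=64)
#     x = torch.randn(2, 16 * 16, 64)             # flattened image tokens
#     context = torch.randn(2, 77, 256)           # e.g. text encoder states
#     mask = torch.ones(2, 77, dtype=torch.bool)  # True = attend, False = ignore
#     attn(x, context, mask).shape                # torch.Size([2, 256, 128])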