# Author: mueller-franzes
# Commit message: init
# Commit: f85e212
import math
import torch
import torch.nn as nn
from monai.networks.layers.utils import get_act_layer
class SinusoidalPosEmb(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of scalar positions/timesteps.

    For an input ``x`` of shape ``(*batch)`` the output has shape
    ``(*batch, emb_dim)``. It holds ``emb_dim // 2`` sines followed by the same
    number of cosines of ``x * f_k``, where
    ``f_k = max_period ** (-k / (emb_dim // 2 - downscale_freq_shift))`` for
    ``k = 0 .. emb_dim // 2 - 1``. The cosines come first if
    ``flip_sin_to_cos`` is set. When ``emb_dim`` is odd, one zero column is
    appended so the width is exactly ``emb_dim``.

    Args:
        emb_dim: Output width.
        downscale_freq_shift: Subtracted from ``emb_dim // 2`` in the frequency
            spacing denominator.
        max_period: Controls the lowest frequency; ``f_k`` ranges from 1 down
            towards ``1 / max_period``.
        flip_sin_to_cos: Emit ``[cos, sin]`` instead of ``[sin, cos]``.

    Raises:
        ValueError: (from ``forward``) if ``emb_dim // 2`` is positive and
            equals ``downscale_freq_shift``, which would divide by zero.
    """

    def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.out_dim = emb_dim  # actual width of the tensor returned by forward
        self.downscale_freq_shift = downscale_freq_shift
        self.max_period = max_period
        self.flip_sin_to_cos = flip_sin_to_cos

    def forward(self, x):
        """Embed ``x`` (any shape) into ``(*x.shape, emb_dim)``."""
        half_dim = self.emb_dim // 2
        denom = half_dim - self.downscale_freq_shift
        if half_dim > 0 and denom == 0:
            raise ValueError(
                f"emb_dim // 2 ({half_dim}) must differ from downscale_freq_shift "
                f"({self.downscale_freq_shift}); the frequency spacing would divide by zero."
            )
        # With no frequencies (emb_dim < 2) the spacing is irrelevant, so avoid
        # dividing by a possibly-zero denominator.
        log_step = math.log(self.max_period) / denom if half_dim > 0 else 0.0
        freqs = torch.exp(-log_step * torch.arange(half_dim, device=x.device))

        # Broadcast over the trailing frequency axis: (*batch, 1) * (half_dim,).
        args = x[..., None] * freqs
        sin, cos = args.sin(), args.cos()
        emb = torch.cat((cos, sin) if self.flip_sin_to_cos else (sin, cos), dim=-1)

        if self.emb_dim % 2 == 1:
            # Zero-pad the last dim so the width matches an odd emb_dim.
            emb = torch.nn.functional.pad(emb, (0, 1))
        return emb
class LearnedSinusoidalPosEmb(nn.Module):
    """Fourier-feature embedding with learnable frequencies.

    Follows @crowsonkb's lead with learned sinusoidal positional embeddings:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    For an input ``x`` of shape ``(*batch)`` the output has shape
    ``(*batch, emb_dim + 1)``. It holds the raw ``x`` value, then
    ``emb_dim // 2`` sines and ``emb_dim // 2`` cosines of ``2*pi*x*w`` for
    learnable frequencies ``w``. When ``emb_dim`` is odd, a zero column is
    appended, so the width is ``emb_dim + 1`` in both cases. That width is
    exposed as ``out_dim``.

    Args:
        emb_dim: Nominal embedding size; ``emb_dim // 2`` frequencies are learned.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.emb_dim = emb_dim
        self.out_dim = emb_dim + 1  # actual width of the tensor returned by forward
        half_dim = emb_dim // 2
        # Learnable frequencies, initialised from a standard normal.
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        """Embed ``x`` (any shape) into ``(*x.shape, emb_dim + 1)``."""
        x = x[..., None]
        # Broadcast (*batch, 1) against the (half_dim,) frequency vector.
        freqs = x * self.weights * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # Prepend the raw input as an extra feature.
        fouriered = torch.cat((x, fouriered), dim=-1)
        if self.emb_dim % 2 == 1:
            # NOTE(review): 1 + 2 * (emb_dim // 2) already equals emb_dim here, so
            # this pad makes the width emb_dim + 1 (with a zero last column).
            # It is kept for backward compatibility with existing consumers.
            fouriered = torch.nn.functional.pad(fouriered, (0, 1))
        return fouriered
class TimeEmbbeding(nn.Module):
    """Map timesteps to an ``emb_dim``-wide embedding.

    The pipeline is ``pos_embedder -> Linear -> activation -> Linear``.

    Args:
        emb_dim: Width of the final embedding.
        pos_embedder: Positional-encoder class/factory, called as
            ``pos_embedder(**pos_embedder_kwargs)``.
        pos_embedder_kwargs: Keyword arguments for ``pos_embedder``. If it lacks
            ``'emb_dim'``, ``emb_dim // 4`` is used. The mapping is copied and
            never modified. ``None`` (the default) means no extra arguments.
        act_name: Activation spec passed to MONAI's ``get_act_layer``
            (default ``("SWISH", {})``, i.e. SiLU).
    """

    def __init__(
        self,
        emb_dim=64,
        pos_embedder=SinusoidalPosEmb,
        pos_embedder_kwargs=None,
        act_name=("SWISH", {}),  # Swish = SiLU
    ):
        super().__init__()
        self.emb_dim = emb_dim
        # Work on a private copy. Writing 'emb_dim' into a shared mutable default
        # (or the caller's dict) would otherwise leak one instance's setting
        # into every later instance.
        pos_embedder_kwargs = dict(pos_embedder_kwargs or {})
        self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim // 4)
        pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim
        self.pos_embedder = pos_embedder(**pos_embedder_kwargs)

        # The encoder's output width can differ from its requested emb_dim
        # (LearnedSinusoidalPosEmb prepends the raw input), so measure it once
        # instead of assuming pos_emb_dim.
        with torch.no_grad():
            self.pos_emb_out_dim = self.pos_embedder(torch.zeros(1)).shape[-1]

        self.time_emb = nn.Sequential(
            self.pos_embedder,
            nn.Linear(self.pos_emb_out_dim, self.emb_dim),
            get_act_layer(act_name),
            nn.Linear(self.emb_dim, self.emb_dim),
        )

    def forward(self, time):
        """Embed ``time``; the result has ``emb_dim`` features in the last dim."""
        return self.time_emb(time)