# Source snapshot metadata: revision 0dee401, 7,202 bytes.
import math
import numpy as np
import torch
import torch.nn as nn
class ModulatedLayerNorm(nn.Module):
    """LayerNorm over the channel dimension, optionally modulated by a tensor.

    Without a conditioning tensor ``w`` this is a plain ``nn.LayerNorm`` over the
    channel (last, after optional permutation) dimension. With ``w`` the output
    is ``gamma * w * LN(x) + beta * w``, where ``gamma`` and ``beta`` are learned
    scalar parameters stored with shape ``(1, 1, 1)``.

    Args:
        num_features: size of the normalized channel dimension.
        eps: epsilon forwarded to ``nn.LayerNorm``.
        channels_first: if True, ``forward`` takes a 4-D tensor with channels in
            dim 1 and permutes it to channels-last around the normalization; if
            False, the input must already have channels in the last dimension.
    """

    def __init__(self, num_features, eps=1e-6, channels_first=True):
        # Must run before any submodule/parameter is assigned.
        super().__init__()
        self.ln = nn.LayerNorm(num_features, eps=eps)
        self.gamma = nn.Parameter(torch.randn(1, 1, 1))
        self.beta = nn.Parameter(torch.randn(1, 1, 1))
        self.channels_first = channels_first

    def forward(self, x, w=None):
        """Normalize ``x`` and, if ``w`` is given, modulate the result by ``w``.

        Args:
            x: input tensor, channels in dim 1 if ``channels_first`` else last.
            w: optional modulation tensor broadcastable against the
                channels-last view of ``x``.

        Returns:
            Tensor with the same shape and layout as ``x``.
        """
        if self.channels_first:
            # Move channels last so LayerNorm normalizes over them.
            x = x.permute(0, 2, 3, 1)
        if w is None:
            x = self.ln(x)
        else:
            x = self.gamma * w * self.ln(x) + self.beta * w
        if self.channels_first:
            x = x.permute(0, 3, 1, 2)
        return x
class ResBlock(nn.Module):
    """Residual block: depthwise conv -> (modulated) LayerNorm -> channel MLP.

    The MLP output is optionally multiplied by a learned per-channel scale
    and then added back onto the block input.

    Args:
        c: number of input/output channels.
        c_hidden: hidden width of the channel MLP.
        c_cond: channels of the conditioning tensor ``s``; when > 0 a linear
            ``cond_mapper`` projects ``s`` to ``c`` channels. ``forward`` must
            not be given ``s`` when this is 0.
        c_skip: channels of the optional ``skip`` tensor that is concatenated
            onto the features before the MLP.
        scaler: optional module applied to the block output.
        layer_scale_init_value: initial value of the per-channel output scale;
            values ``<= 0`` disable it.
    """

    def __init__(self, c, c_hidden, c_cond=0, c_skip=0, scaler=None, layer_scale_init_value=1e-6):
        super().__init__()
        self.depthwise = nn.Sequential(
            # Pad by 1 so the unpadded 3x3 conv keeps H and W; otherwise the
            # residual addition in forward() fails on a shape mismatch.
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c),
        )
        self.ln = ModulatedLayerNorm(c, channels_first=False)
        self.channelwise = nn.Sequential(
            nn.Linear(c + c_skip, c_hidden),
            # Without a nonlinearity the two Linear layers collapse into one.
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(c), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.scaler = scaler
        if c_cond > 0:
            self.cond_mapper = nn.Linear(c_cond, c)

    def forward(self, x, s=None, skip=None):
        """Apply the block.

        Args:
            x: 4-D feature map with ``c`` channels in dim 1.
            s: optional 4-D conditioning tensor with ``c_cond`` channels in
                dim 1; a 1x1 spatial size is broadcast, any other mismatched
                size is bilinearly resized to match ``x``.
            skip: optional 4-D tensor with ``c_skip`` channels in dim 1 and the
                same spatial size as ``x``.

        Returns:
            Tensor shaped like ``x`` (before the optional ``scaler``).
        """
        res = x
        x = self.depthwise(x)
        if s is not None:
            if s.size(2) == s.size(3) == 1:
                s = s.expand(-1, -1, x.size(2), x.size(3))
            elif s.size(2) != x.size(2) or s.size(3) != x.size(3):
                s = nn.functional.interpolate(s, size=x.shape[-2:], mode='bilinear')
            # Channels-last so the Linear projection acts on the channel dim.
            s = self.cond_mapper(s.permute(0, 2, 3, 1))
        x = self.ln(x.permute(0, 2, 3, 1), s)
        if skip is not None:
            x = torch.cat([x, skip.permute(0, 2, 3, 1)], dim=-1)
        x = self.channelwise(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = res + x.permute(0, 3, 1, 2)
        if self.scaler is not None:
            x = self.scaler(x)
        return x
class DenoiseUNet(nn.Module):
    """U-Net over a 2-D grid of discrete tokens, built from ``ResBlock``s.

    Tokens are embedded, passed through a down path and an up path of
    ``ResBlock``s conditioned on ``c`` concatenated with a sinusoidal
    embedding of ``r``, and projected by a 1x1 conv to ``num_labels``
    output channels per position.

    Args:
        num_labels: number of token ids (embedding rows and output channels).
        c_hidden: channel width of the deepest level; shallower levels halve
            the width once per level.
        c_clip: channels of the conditioning tensor ``c``.
        c_r: size of the sinusoidal ``r`` embedding.
        down_levels: number of ``ResBlock``s per level on the down path.
        up_levels: number of ``ResBlock``s per level on the up path; must have
            the same length as ``down_levels``.

    Raises:
        ValueError: if ``up_levels`` and ``down_levels`` differ in length.
    """

    def __init__(self, num_labels, c_hidden=1280, c_clip=1024, c_r=64,
                 down_levels=(4, 8, 16), up_levels=(16, 8, 4)):
        super().__init__()
        if len(up_levels) != len(down_levels):
            raise ValueError("up_levels and down_levels must have the same length")
        self.num_labels = num_labels
        self.c_r = c_r
        # Store copies so the model never aliases the caller's sequences.
        self.down_levels = list(down_levels)
        self.up_levels = list(up_levels)
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(len(down_levels)))]
        self.embedding = nn.Embedding(num_labels, c_levels[0])

        # Shrink each block's output projection by 1/sqrt(#blocks on that path).
        down_scale = 1.0 / math.sqrt(max(sum(down_levels), 1))
        up_scale = 1.0 / math.sqrt(max(sum(up_levels), 1))

        self.down_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(down_levels):
            blocks = []
            if i > 0:
                # Stride-2 conv: halve resolution and widen to the next level.
                blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            for _ in range(num_blocks):
                block = ResBlock(c_levels[i], c_levels[i] * 4, c_clip + c_r)
                with torch.no_grad():
                    block.channelwise[-1].weight.mul_(down_scale)
                blocks.append(block)
            self.down_blocks.append(nn.ModuleList(blocks))

        self.up_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(up_levels):
            c_level = c_levels[-1 - i]
            blocks = []
            for j in range(num_blocks):
                # The first block of every level after the deepest one also
                # consumes the matching down-path output as a skip input.
                c_skip = c_level if (j == 0 and i > 0) else 0
                block = ResBlock(c_level, c_level * 4, c_clip + c_r, c_skip)
                with torch.no_grad():
                    block.channelwise[-1].weight.mul_(up_scale)
                blocks.append(block)
            if i < len(up_levels) - 1:
                # Stride-2 transposed conv: double resolution, narrow width.
                blocks.append(nn.ConvTranspose2d(c_level, c_levels[-2 - i], kernel_size=4, stride=2, padding=1))
            self.up_blocks.append(nn.ModuleList(blocks))

        self.clf = nn.Conv2d(c_levels[0], num_labels, kernel_size=1)

    def gamma(self, r):
        """Return ``cos(r * pi / 2)`` elementwise."""
        return (r * torch.pi / 2).cos()

    def add_noise(self, x, r, random_x=None):
        """Replace tokens of ``x`` at random with tokens from ``random_x``.

        Each position of sample ``b`` is replaced with probability
        ``gamma(r[b])`` (clamped to [0, 1]).

        Args:
            x: integer token tensor; ``r`` is broadcast as ``r[:, None, None]``,
                so ``x`` is expected to be 3-D with the batch in dim 0.
            r: 1-D tensor, one value per sample.
            random_x: replacement tokens shaped like ``x``; drawn uniformly
                from ``[0, num_labels)`` when omitted.

        Returns:
            ``(noised_x, mask)`` where ``mask`` is a long tensor of 0/1 marking
            the replaced positions.
        """
        # In float32, cos(pi/2) is about -4.4e-8, which torch.bernoulli rejects;
        # clamp into the valid probability range.
        r = self.gamma(r).clamp(0.0, 1.0)[:, None, None]
        mask = torch.bernoulli(r * torch.ones_like(x)).round().long()
        if random_x is None:
            random_x = torch.randint_like(x, 0, self.num_labels)
        x = x * (1 - mask) + random_x * mask
        return x, mask

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of ``gamma(r)``.

        Args:
            r: 1-D tensor, one value per sample.
            max_positions: scale applied to ``gamma(r)`` and the base of the
                frequency progression.

        Returns:
            Tensor of shape ``(len(r), c_r)`` in ``r``'s dtype; sin features,
            then cos features, then a zero column when ``c_r`` is odd.
        """
        dtype = r.dtype
        r = self.gamma(r) * max_positions
        half_dim = self.c_r // 2
        # Guard half_dim == 1 (c_r in {2, 3}) against division by zero.
        emb = math.log(max_positions) / max(half_dim - 1, 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad to reach c_r features
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb.to(dtype)

    def _down_encode_(self, x, s):
        """Run the down path and return per-level outputs, deepest first."""
        level_outputs = []
        for blocks in self.down_blocks:
            for block in blocks:
                if isinstance(block, ResBlock):
                    x = block(x, s)
                else:
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, s):
        """Run the up path starting from the deepest down-path output.

        The first ``ResBlock`` of every level after the first also receives
        ``level_outputs[i]`` as its skip input.
        """
        x = level_outputs[0]
        for i, blocks in enumerate(self.up_blocks):
            for j, block in enumerate(blocks):
                if isinstance(block, ResBlock):
                    if i > 0 and j == 0:
                        x = block(x, s, level_outputs[i])
                    else:
                        x = block(x, s)
                else:
                    x = block(x)
        return x

    def forward(self, x, c, r):
        """Predict per-position outputs for token grid ``x``.

        Args:
            x: integer token tensor; after embedding it is permuted with
                ``(0, 3, 1, 2)``, so it is expected to be ``(B, H, W)``.
            c: conditioning, either 2-D ``(B, c_clip)`` or 4-D with
                ``c_clip`` channels in dim 1.
            r: 1-D tensor of per-sample values (the caller comment upstream
                describes it as uniform in [0, 1]).

        Returns:
            Tensor with ``num_labels`` channels in dim 1.
        """
        r_embed = self.gen_r_embedding(r)
        x = self.embedding(x).permute(0, 3, 1, 2)
        if c.dim() == 2:
            # Spatially constant conditioning as a 1x1 map.
            s = torch.cat([c, r_embed], dim=-1)[:, :, None, None]
        else:
            r_embed = r_embed[:, :, None, None].expand(-1, -1, c.size(2), c.size(3))
            s = torch.cat([c, r_embed], dim=1)
        level_outputs = self._down_encode_(x, s)
        x = self._up_decode(level_outputs, s)
        return self.clf(x)
if __name__ == '__main__':
    # Smoke test: build the default model, print its parameter count and run
    # one forward pass. Fall back to CPU so the script runs without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = DenoiseUNet(1024).to(device)
    print(sum(p.numel() for p in model.parameters()))
    x = torch.randint(0, 1024, (1, 32, 32)).long().to(device)
    c = torch.randn((1, 1024)).to(device)
    r = torch.rand(1).to(device)
    # No gradients are needed for a shape/smoke check; avoids storing the graph.
    with torch.no_grad():
        model(x, c, r)