mridulk committed on
Commit
d759c1a
1 Parent(s): d39ef0a

added few more files

Browse files
Files changed (31) hide show
  1. ldm/modules/attention.py +261 -0
  2. ldm/modules/diffusionmodules/__init__.py +0 -0
  3. ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc +0 -0
  4. ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc +0 -0
  5. ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc +0 -0
  6. ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc +0 -0
  7. ldm/modules/diffusionmodules/model.py +840 -0
  8. ldm/modules/diffusionmodules/openaimodel.py +963 -0
  9. ldm/modules/diffusionmodules/util.py +267 -0
  10. ldm/modules/discriminator/__pycache__/model.cpython-38.pyc +0 -0
  11. ldm/modules/discriminator/model.py +69 -0
  12. ldm/modules/distributions/__init__.py +0 -0
  13. ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc +0 -0
  14. ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc +0 -0
  15. ldm/modules/distributions/distributions.py +92 -0
  16. ldm/modules/ema.py +76 -0
  17. ldm/modules/encoders/__init__.py +0 -0
  18. ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc +0 -0
  19. ldm/modules/encoders/__pycache__/modules.cpython-38.pyc +0 -0
  20. ldm/modules/encoders/modules.py +404 -0
  21. ldm/modules/image_degradation/__init__.py +2 -0
  22. ldm/modules/image_degradation/__pycache__/__init__.cpython-38.pyc +0 -0
  23. ldm/modules/image_degradation/__pycache__/bsrgan.cpython-38.pyc +0 -0
  24. ldm/modules/image_degradation/__pycache__/bsrgan_light.cpython-38.pyc +0 -0
  25. ldm/modules/image_degradation/__pycache__/utils_image.cpython-38.pyc +0 -0
  26. ldm/modules/image_degradation/bsrgan.py +730 -0
  27. ldm/modules/image_degradation/bsrgan_light.py +650 -0
  28. ldm/modules/image_degradation/utils/test.png +0 -0
  29. ldm/modules/image_degradation/utils_image.py +916 -0
  30. ldm/modules/util.py +86 -0
  31. ldm/modules/x_transformer.py +641 -0
ldm/modules/attention.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from ldm.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
def exists(val):
    """Return True when `val` is anything other than None (falsy values count as existing)."""
    if val is None:
        return False
    return True
13
+
14
+
15
def uniq(arr):
    """Return a dict-keys view of the distinct elements of `arr`, in first-seen order.

    Elements must be hashable because they are used as dictionary keys.
    """
    seen = dict.fromkeys(arr, True)
    return seen.keys()
17
+
18
+
19
def default(val, d):
    """Return `val` unless it is None, otherwise fall back to `d`.

    When the fallback `d` is a plain Python function (per `inspect.isfunction`)
    it is called with no arguments and its result is returned, so expensive
    defaults are only built when actually needed.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
23
+
24
+
25
def max_neg_value(t):
    """Return the most negative finite value representable in `t`'s floating-point dtype."""
    dtype_limits = torch.finfo(t.dtype)
    return -dtype_limits.max
27
+
28
+
29
def init_(tensor):
    """Fill `tensor` in place with uniform samples and return it.

    Samples are drawn from U(-1/sqrt(d), 1/sqrt(d)), where d is the size of
    the tensor's last dimension.
    """
    last_dim = tensor.shape[-1]
    bound = 1 / math.sqrt(last_dim)
    tensor.uniform_(-bound, bound)
    return tensor
34
+
35
+
36
+ # feedforward
37
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    One linear layer produces `2 * dim_out` features; the first half is the
    value and the second half, passed through GELU, is the multiplicative gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)
45
+
46
+
47
class FeedForward(nn.Module):
    """Feed-forward sub-layer: input projection, dropout, output projection.

    Args:
        dim: size of the input feature dimension.
        dim_out: size of the output feature dimension; falls back to `dim`.
        mult: expansion factor; the hidden width is `int(dim * mult)`.
        glu: when True the input projection is a `GEGLU`, otherwise it is
            Linear followed by GELU.
        dropout: dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, hidden_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())

        layers = [
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
65
+
66
+
67
def zero_module(module):
    """Set every parameter of `module` to zero in place and hand the module back."""
    for param in module.parameters():
        # detach() shares storage with the parameter, so the in-place zeroing
        # modifies the parameter itself without being tracked by autograd.
        param.detach().zero_()
    return module
74
+
75
+
76
def Normalize(in_channels, num_groups=32):
    """Build the GroupNorm layer used by the attention modules in this file.

    The group count used to be hard-coded to 32; it is now a keyword with the
    same default, matching the signature of the `Normalize` helper in
    `ldm/modules/diffusionmodules/model.py`. Existing callers are unaffected.

    Args:
        in_channels: number of channels of the tensor to normalise.
        num_groups: number of normalisation groups (default 32). GroupNorm
            requires `in_channels` to be divisible by this value.

    Returns:
        A `torch.nn.GroupNorm` with `eps=1e-6` and learnable affine parameters.
    """
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
78
+
79
+
80
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a 4-D feature map.

    Keys are soft-maxed over the flattened spatial axis, contracted with the
    values into a per-head (d x e) context matrix, and that context is then
    applied to the queries. No attention matrix between pairs of positions is
    ever formed.

    Args:
        dim: number of input (and output) channels.
        heads: number of attention heads.
        dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # a single bias-free 1x1 conv yields q, k and v stacked on the channel axis
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        stacked = self.to_qkv(x)
        q, k, v = rearrange(
            stacked,
            'b (qkv heads c) h w -> qkv b heads c (h w)',
            heads=self.heads,
            qkv=3,
        )
        k = k.softmax(dim=-1)  # normalise keys over spatial positions
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(
            attended,
            'b heads c (h w) -> b (heads c) h w',
            heads=self.heads,
            h=height,
            w=width,
        )
        return self.to_out(attended)
97
+
98
+
99
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is group-normalised, projected to q/k/v with 1x1 convolutions,
    attended with softmax(q·k / sqrt(c)), projected back and added to the
    input as a residual. The output has the same shape as the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1 convolution that keeps the channel count unchanged
            return torch.nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # attention logits between every pair of spatial positions
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        scores = torch.einsum('bij,bjk->bik', q, k)

        scores = scores * (int(c)**(-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # weighted sum of the values; weights are transposed so that the
        # contraction runs over the key positions
        v = rearrange(v, 'b c h w -> b c (h w)')
        weights = rearrange(weights, 'b i j -> b j i')
        attended = torch.einsum('bij,bjk->bik', v, weights)
        attended = rearrange(attended, 'b c (h w) -> b c h w', h=h)
        attended = self.proj_out(attended)

        return x + attended
150
+
151
+
152
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from `x` onto `context`.

    When no context is supplied, `x` attends to itself (self-attention).

    Args:
        query_dim: feature size of the query input (and of the output).
        context_dim: feature size of the context; falls back to `query_dim`.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        n_heads = self.heads

        queries = self.to_q(x)
        context = default(context, x)
        keys = self.to_k(context)
        values = self.to_v(context)

        def split_heads(t):
            # fold the head axis into the batch axis
            return rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)

        queries = split_heads(queries)
        keys = split_heads(keys)
        values = split_heads(values)

        sim = einsum('b i d, b j d -> b i j', queries, keys) * self.scale

        if exists(mask):
            # flatten the mask per batch element, repeat it for every head and
            # push the positions where the mask is False to the most negative
            # finite value so they vanish in the softmax
            key_mask = rearrange(mask, 'b ... -> b (...)')
            fill_value = -torch.finfo(sim.dtype).max
            key_mask = repeat(key_mask, 'b j -> (b h) () j', h=n_heads)
            sim.masked_fill_(~key_mask, fill_value)

        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, values)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(out)
194
+
195
+
196
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is preceded by its own LayerNorm and wrapped
    in a residual connection. The second attention layer attends to `context`
    and degenerates to self-attention when no context is passed.

    Args:
        dim: feature size of the token sequence.
        n_heads: number of attention heads.
        d_head: feature size per head.
        dropout: dropout probability forwarded to the attention and FF layers.
        context_dim: feature size of the conditioning context (None -> `dim`).
        gated_ff: use the GEGLU variant of the feed-forward layer.
        checkpoint: flag handed to the imported `checkpoint` helper in
            `forward` (presumably toggles gradient checkpointing; confirm
            against `ldm.modules.diffusionmodules.util`).
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        shared = dict(heads=n_heads, dim_head=d_head, dropout=dropout)
        # attn1 gets no context_dim, so it is always a self-attention
        self.attn1 = CrossAttention(query_dim=dim, **shared)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # attn2 is self-attention too whenever the context is None
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, **shared)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # `checkpoint` here is the module-level helper, not the constructor flag
        inputs = (x, context)
        return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        self_attended = self.attn1(self.norm1(x))
        x = self_attended + x
        cross_attended = self.attn2(self.norm2(x), context=context)
        x = cross_attended + x
        fed_forward = self.ff(self.norm3(x))
        x = fed_forward + x
        return x
216
+
217
+
218
class SpatialTransformer(nn.Module):
    """Transformer stack applied to image-like (b, c, h, w) data.

    The input is group-normalised and projected to `n_heads * d_head`
    channels, flattened to a (b, h*w, c) token sequence, run through `depth`
    `BasicTransformerBlock`s, reshaped back to an image and projected to the
    original channel count. The output projection starts at zero, so a freshly
    built module returns its input unchanged through the residual connection.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        blocks = []
        for _ in range(depth):
            blocks.append(
                BasicTransformerBlock(inner_dim, n_heads, d_head,
                                      dropout=dropout, context_dim=context_dim)
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        out_conv = nn.Conv2d(inner_dim,
                             in_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
        self.proj_out = zero_module(out_conv)

    def forward(self, x, context=None):
        # with context=None the cross-attention layers act as self-attention
        b, c, h, w = x.shape
        residual = x
        tokens = self.proj_in(self.norm(x))
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            tokens = block(tokens, context=context)
        feature_map = rearrange(tokens, 'b (h w) c -> b c h w', h=h, w=w)
        feature_map = self.proj_out(feature_map)
        return feature_map + residual
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (161 Bytes). View file
 
ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc ADDED
Binary file (20.7 kB). View file
 
ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc ADDED
Binary file (22.8 kB). View file
 
ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc ADDED
Binary file (9.44 kB). View file
 
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from ldm.util import instantiate_from_config
9
+ from ldm.modules.attention import LinearAttention
10
+
11
+
12
def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal timestep embeddings.

    The layout is `[sin(t * f_0), ..., sin(t * f_{h-1}), cos(t * f_0), ...]`
    with `h = embedding_dim // 2` and geometrically spaced frequencies
    `f_i = exp(-i * log(10000) / (h - 1))`. This follows the Fairseq /
    tensor2tensor formulation used by Denoising Diffusion Probabilistic
    Models, which differs slightly from Section 3.5 of "Attention Is All
    You Need".

    Args:
        timesteps: 1-D tensor of timesteps.
        embedding_dim: width of the embedding; odd widths get one trailing
            zero column.

    Returns:
        Float tensor of shape `(len(timesteps), embedding_dim)` on the same
        device as `timesteps`.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # For embedding_dim 2 or 3, half_dim - 1 is zero and the original division
    # raised ZeroDivisionError. Clamping the denominator to 1 leaves every
    # half_dim >= 2 untouched and gives the single frequency f_0 = 1.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the last column
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
31
+
32
+
33
def nonlinearity(x):
    """Swish activation: `x * sigmoid(x)`, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
36
+
37
+
38
def Normalize(in_channels, num_groups=32):
    """Return the affine GroupNorm (eps=1e-6) used throughout this module.

    Args:
        in_channels: number of channels to normalise.
        num_groups: number of groups the channels are split into.
    """
    norm_layer = torch.nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
    return norm_layer
40
+
41
+
42
class Upsample(nn.Module):
    """Double the spatial size by nearest-neighbour interpolation.

    Args:
        in_channels: channel count of the input (and output).
        with_conv: when truthy, a 3x3 same-padding convolution is applied
            after the interpolation; otherwise no `conv` attribute is created.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        self.conv = torch.nn.Conv2d(in_channels,
                                    in_channels,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
58
+
59
+
60
class Downsample(nn.Module):
    """Halve the spatial size, by strided convolution or by average pooling.

    Args:
        in_channels: channel count of the input (and output).
        with_conv: when truthy, use a 3x3 stride-2 convolution; otherwise use
            2x2 average pooling and create no `conv` attribute.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        # torch convs only pad symmetrically, so the conv itself is unpadded
        # and forward() pads one side by hand
        self.conv = torch.nn.Conv2d(in_channels,
                                    in_channels,
                                    kernel_size=3,
                                    stride=2,
                                    padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # one zero column on the right and one zero row at the bottom
        right_bottom_pad = (0, 1, 0, 1)
        padded = torch.nn.functional.pad(x, right_bottom_pad, mode="constant", value=0)
        return self.conv(padded)
80
+
81
+
82
class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    Args (keyword-only):
        in_channels: channels of the input.
        out_channels: channels of the output; defaults to `in_channels`.
        conv_shortcut: when the channel counts differ, use a 3x3 conv on the
            skip path instead of a 1x1 conv.
        dropout: dropout probability applied before the second conv.
        temb_channels: width of the timestep embedding; a projection layer is
            only created when this is greater than zero.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)

        # a learned skip projection is only needed when the channel count changes
        if self.in_channels == self.out_channels:
            return
        if self.use_conv_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                 out_channels,
                                                 kernel_size=3,
                                                 stride=1,
                                                 padding=1)
        else:
            self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                out_channels,
                                                kernel_size=1,
                                                stride=1,
                                                padding=0)

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # broadcast the projected embedding over the two spatial axes
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
142
+
143
+
144
class LinAttnBlock(LinearAttention):
    """`LinearAttention` with the one-argument constructor of `AttnBlock`.

    Configured as a single head whose width equals the channel count.
    """

    def __init__(self, in_channels):
        config = dict(dim=in_channels, heads=1, dim_head=in_channels)
        super().__init__(**config)
148
+
149
+
150
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to q/k/v with 1x1 convolutions,
    attended with softmax(q·k / sqrt(c)), projected back and added to the
    input as a residual. The output has the same shape as the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1 convolution that keeps the channel count unchanged
            return torch.nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # attention weights
        b, c, h, w = q.shape
        n_pos = h*w
        q = q.reshape(b, c, n_pos).permute(0, 2, 1)   # b,hw,c
        k = k.reshape(b, c, n_pos)                    # b,c,hw
        # scores[b,i,j] = sum_c q[b,i,c] * k[b,c,j]
        scores = torch.bmm(q, k)                      # b,hw,hw
        scores = scores * (int(c)**(-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # attend to values
        v = v.reshape(b, c, n_pos)
        # after the transpose the first hw axis indexes keys, the second queries
        weights = weights.permute(0, 2, 1)
        # out[b,c,j] = sum_i v[b,c,i] * weights[b,i,j]
        out = torch.bmm(v, weights)
        out = out.reshape(b, c, h, w)

        out = self.proj_out(out)

        return x + out
203
+
204
+
205
def make_attn(in_channels, attn_type="vanilla"):
    """Create the attention layer selected by `attn_type`.

    Args:
        in_channels: channel count handed to the created layer.
        attn_type: "vanilla" -> `AttnBlock`, "linear" -> `LinAttnBlock`,
            "none" -> `nn.Identity`.

    Raises:
        AssertionError: if `attn_type` is not one of the three known names.

    Prints a one-line notice of the choice on every call.
    """
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "none":
        return nn.Identity(in_channels)
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    # only "linear" is left after the assertion above
    return LinAttnBlock(in_channels)
214
+
215
+
216
class Model(nn.Module):
    """Encoder-decoder network with skip connections and optional timestep conditioning.

    The downsampling path stores every intermediate activation; the
    upsampling path pops them in reverse order and concatenates each one with
    the current activation along the channel axis before every residual block.

    Args (keyword-only):
        ch: base channel count.
        out_ch: channels of the output.
        ch_mult: per-level channel multipliers; its length is the number of
            resolution levels.
        num_res_blocks: residual blocks per level on the way down (one more
            per level on the way up).
        attn_resolutions: resolutions at which an attention layer follows
            each residual block.
        dropout: dropout probability inside the residual blocks.
        resamp_with_conv: use convolutions in the up/down-sampling layers.
        in_channels: channels fed to the first convolution (after any
            context concatenation done in `forward`).
        resolution: nominal input resolution, used to decide where attention
            is inserted.
        use_timestep: build and use the timestep-embedding MLP.
        use_linear_attn: shortcut that forces `attn_type="linear"`.
        attn_type: attention flavour passed to `make_attn`.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # two-layer MLP applied to the sinusoidal timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ])

        def res_block(c_in, c_out):
            # every residual block in this network shares these settings
            return ResnetBlock(in_channels=c_in,
                               out_channels=c_out,
                               temb_channels=self.temb_ch,
                               dropout=dropout)

        # ---- downsampling path ----
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            level_blocks = nn.ModuleList()
            level_attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                level_blocks.append(res_block(block_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    level_attn.append(make_attn(block_in, attn_type=attn_type))
            stage = nn.Module()
            stage.block = level_blocks
            stage.attn = level_attn
            if i_level != self.num_resolutions-1:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # ---- bottleneck ----
        self.mid = nn.Module()
        self.mid.block_1 = res_block(block_in, block_in)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = res_block(block_in, block_in)

        # ---- upsampling path ----
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            level_blocks = nn.ModuleList()
            level_attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                # the last block of a level consumes the skip tensor that
                # entered the matching down level, which has that level's
                # input width
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                else:
                    skip_in = ch*ch_mult[i_level]
                level_blocks.append(res_block(block_in+skip_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    level_attn.append(make_attn(block_in, attn_type=attn_type))
            stage = nn.Module()
            stage.block = level_blocks
            stage.attn = level_attn
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, stage)  # prepend so that self.up[i] matches level i

        # ---- output head ----
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t=None, context=None):
        # The spatial size of x is not checked against self.resolution.
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)

        if self.use_timestep:
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling: remember every activation for the skip connections
        hs = [self.conv_in(x)]
        last_level = self.num_resolutions-1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](hs[-1], temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
                hs.append(h)
            if i_level != last_level:
                hs.append(stage.downsample(hs[-1]))

        # bottleneck
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: consume the stored activations in reverse order
        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks+1):
                merged = torch.cat([h, hs.pop()], dim=1)
                h = stage.block[i_block](merged, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != 0:
                h = stage.upsample(h)

        # output head
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        """Return the weight tensor of the final convolution."""
        return self.conv_out.weight
370
+
371
+
372
class Encoder(nn.Module):
    """Convolutional encoder: downsampling residual stages, bottleneck, output conv.

    No timestep embedding is used (`temb_ch` is 0 and `temb` is always None).
    The output has `2 * z_channels` channels when `double_z` is set, otherwise
    `z_channels`. `out_ch` and any extra keyword arguments are accepted but
    not used by this constructor.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # running channel count; kept on the instance and left at its final value
        self.block_in = None

        def res_block(c_in, c_out):
            return ResnetBlock(in_channels=c_in,
                               out_channels=c_out,
                               temb_channels=self.temb_ch,
                               dropout=dropout)

        # ---- downsampling path ----
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            level_blocks = nn.ModuleList()
            level_attn = nn.ModuleList()
            self.block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                level_blocks.append(res_block(self.block_in, block_out))
                self.block_in = block_out
                if curr_res in attn_resolutions:
                    level_attn.append(make_attn(self.block_in, attn_type=attn_type))
            stage = nn.Module()
            stage.block = level_blocks
            stage.attn = level_attn
            if i_level != self.num_resolutions-1:
                stage.downsample = Downsample(self.block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # ---- bottleneck ----
        self.mid = nn.Module()
        self.mid.block_1 = res_block(self.block_in, self.block_in)
        self.mid.attn_1 = make_attn(self.block_in, attn_type=attn_type)
        self.mid.block_2 = res_block(self.block_in, self.block_in)

        # ---- output head ----
        latent_channels = 2*z_channels if double_z else z_channels
        self.norm_out = Normalize(self.block_in)
        self.conv_out = torch.nn.Conv2d(self.block_in,
                                        latent_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # the encoder is not conditioned on a timestep
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        last_level = self.num_resolutions-1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](hs[-1], temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
                hs.append(h)
            if i_level != last_level:
                hs.append(stage.downsample(hs[-1]))

        # bottleneck
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # output head
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
465
+
466
+
467
class Decoder(nn.Module):
    """Convolutional decoder: input conv, bottleneck, upsampling residual stages, output conv.

    No timestep embedding is used (`temb_ch` is 0 and `temb` is always None).

    Notable keyword arguments:
        give_pre_end: return the activation before the final norm/conv head.
        tanh_out: apply `tanh` to the final output.
        use_linear_attn: shortcut that forces `attn_type="linear"`.

    Extra keyword arguments are accepted and ignored. The constructor prints
    the latent shape it was configured for.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        def res_block(c_in, c_out):
            return ResnetBlock(in_channels=c_in,
                               out_channels=c_out,
                               temb_channels=self.temb_ch,
                               dropout=dropout)

        # channel count and resolution at the lowest level
        in_ch_mult = (1,)+tuple(ch_mult)  # not used further in this constructor
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # latent -> block_in channels
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # ---- bottleneck ----
        self.mid = nn.Module()
        self.mid.block_1 = res_block(block_in, block_in)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = res_block(block_in, block_in)

        # ---- upsampling path ----
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            level_blocks = nn.ModuleList()
            level_attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for _ in range(self.num_res_blocks+1):
                level_blocks.append(res_block(block_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    level_attn.append(make_attn(block_in, attn_type=attn_type))
            stage = nn.Module()
            stage.block = level_blocks
            stage.attn = level_attn
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, stage)  # prepend so that self.up[i] matches level i

        # ---- output head ----
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        # The shape of z is recorded but not validated against self.z_shape.
        self.last_z_shape = z.shape

        # the decoder is not conditioned on a timestep
        temb = None

        # latent -> block_in channels
        h = self.conv_in(z)

        # bottleneck
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling, from the lowest level up
        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks+1):
                h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != 0:
                h = stage.upsample(h)

        if self.give_pre_end:
            return h

        # output head
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
574
+
575
+
576
class SimpleDecoder(nn.Module):
    """Small fixed decoder: 1x1 conv, three residual blocks, 1x1 conv, 2x upsample, output head.

    The residual blocks widen the channels to 2x and 4x and then narrow them
    back to 2x before the second 1x1 conv restores `in_channels`. Extra
    positional and keyword arguments are accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()

        def res_block(c_in, c_out):
            return ResnetBlock(in_channels=c_in,
                               out_channels=c_out,
                               temb_channels=0, dropout=0.0)

        self.model = nn.ModuleList([
            nn.Conv2d(in_channels, in_channels, 1),
            res_block(in_channels, 2 * in_channels),
            res_block(2 * in_channels, 4 * in_channels),
            res_block(4 * in_channels, 2 * in_channels),
            nn.Conv2d(2*in_channels, in_channels, 1),
            Upsample(in_channels, with_conv=True),
        ])
        # output head
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        for idx, layer in enumerate(self.model):
            # positions 1-3 hold the residual blocks, which also take a
            # (here absent) timestep embedding
            if 1 <= idx <= 3:
                x = layer(x, None)
            else:
                x = layer(x)

        h = nonlinearity(self.norm_out(x))
        return self.conv_out(h)
610
+
611
+
612
class UpsampleDecoder(nn.Module):
    """Stack of residual stages with a 2x conv-upsample between consecutive stages.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        ch: base channel count; stage `i` has `ch * ch_mult[i]` channels.
        num_res_blocks: each stage holds `num_res_blocks + 1` residual blocks.
        resolution: accepted for interface compatibility; it does not affect
            the layers that are built.
        ch_mult: per-stage channel multipliers; its length is the number of
            stages, so there are `len(ch_mult) - 1` upsampling layers.
        dropout: dropout probability inside the residual blocks.
    """

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        # The original also tracked a `curr_res` value derived from
        # `resolution`; it was never read, so it has been removed.
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            stage_blocks = []
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                stage_blocks.append(ResnetBlock(in_channels=block_in,
                                                out_channels=block_out,
                                                temb_channels=self.temb_ch,
                                                dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(stage_blocks))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))

        # output head
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        h = x
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                # no timestep embedding is used by this decoder
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != last_level:
                # upsample_blocks[i] sits between stage i and stage i + 1
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
658
+
659
+
660
class LatentRescaler(nn.Module):
    """Resize a feature map by ``factor`` between two stacks of residual blocks.

    Pipeline: 3x3 conv to ``mid_channels`` -> ``depth`` ResnetBlocks ->
    ``interpolate`` to the rounded scaled size -> ``AttnBlock`` -> ``depth``
    ResnetBlocks -> 1x1 conv to ``out_channels``.

    :param factor: multiplicative factor applied to the two trailing spatial sizes.
    :param in_channels: channels of the input.
    :param mid_channels: channels used by all internal blocks.
    :param out_channels: channels of the output.
    :param depth: number of ResnetBlocks in each of the two stacks.
    """

    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels,
                                 mid_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        def build_stack():
            # `depth` channel-preserving residual blocks without timestep embedding.
            return nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                              out_channels=mid_channels,
                                              temb_channels=0,
                                              dropout=0.0) for _ in range(depth)])

        self.res_block1 = build_stack()
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = build_stack()

        self.conv_out = nn.Conv2d(mid_channels,
                                  out_channels,
                                  kernel_size=1,
                                  )

    def forward(self, x):
        """Return ``x`` rescaled spatially by ``self.factor`` with ``out_channels`` channels."""
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        target_h = int(round(x.shape[2] * self.factor))
        target_w = int(round(x.shape[3] * self.factor))
        x = torch.nn.functional.interpolate(x, size=(target_h, target_w))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        return self.conv_out(x)
695
+
696
+
697
class MergedRescaleEncoder(nn.Module):
    """An ``Encoder`` followed by a ``LatentRescaler``.

    The encoder is configured to emit ``ch * ch_mult[-1]`` channels
    (``double_z=False``); the rescaler keeps that width internally and maps
    it to ``out_ch`` while resizing by ``rescale_factor``.
    """

    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        # Width shared by the encoder output and the rescaler's internal blocks.
        latent_width = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=latent_width,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=latent_width,
            mid_channels=latent_width,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Encode ``x`` and rescale the resulting latent."""
        return self.rescaler(self.encoder(x))
714
+
715
+
716
class MergedRescaleDecoder(nn.Module):
    """A ``LatentRescaler`` followed by a ``Decoder``.

    The rescaler lifts the ``z_channels`` input to ``z_channels * ch_mult[-1]``
    channels while resizing by ``rescale_factor``; the decoder consumes that
    width and produces ``out_ch`` channels.
    """

    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        # Channel width handed from the rescaler to the decoder.
        bridge_width = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=bridge_width,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=bridge_width,
            out_channels=bridge_width,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Rescale the latent ``x`` and decode it."""
        return self.decoder(self.rescaler(x))
731
+
732
+
733
class Upsampler(nn.Module):
    """A ``LatentRescaler`` followed by a ``Decoder`` sized from ``in_size``/``out_size``.

    The decoder gets ``int(log2(out_size // in_size)) + 1`` levels, all with
    multiplier ``ch_mult``; the rescaler factor is ``1 + (out_size % in_size)``.

    :param in_size: input size; must not exceed ``out_size``.
    :param out_size: target size, also passed to the decoder as ``resolution``.
    :param in_channels: channels of the input (also the decoder base width).
    :param out_channels: channels of the decoder output.
    :param ch_mult: single multiplier repeated for every decoder level.
    """

    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        n_levels = int(np.log2(out_size // in_size)) + 1
        factor_up = 1. + (out_size % in_size)
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(
            factor=factor_up,
            in_channels=in_channels,
            mid_channels=2 * in_channels,
            out_channels=in_channels,
        )
        self.decoder = Decoder(
            out_ch=out_channels,
            resolution=out_size,
            z_channels=in_channels,
            num_res_blocks=2,
            attn_resolutions=[],
            in_channels=None,
            ch=in_channels,
            ch_mult=[ch_mult] * n_levels,
        )

    def forward(self, x):
        """Rescale ``x`` and decode it."""
        return self.decoder(self.rescaler(x))
750
+
751
+
752
class Resize(nn.Module):
    """Resize a tensor with ``torch.nn.functional.interpolate``.

    :param in_channels: only relevant for the (unimplemented) learned variant.
    :param learned: request learned resampling; not implemented, so passing
        ``True`` prints a note and raises ``NotImplementedError``.
    :param mode: interpolation mode forwarded to ``interpolate``.
    :raises NotImplementedError: if ``learned`` is true.
    """

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # `__name__` (not `__name`): inside a class body `__name` is
            # name-mangled and raised AttributeError before reaching the
            # intended NotImplementedError below.
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            # The learned path (a stride-2 4x4 convolution over `in_channels`)
            # was never enabled; construction stops here.
            raise NotImplementedError()

    def forward(self, x, scale_factor=1.0):
        """Return ``x`` unchanged for ``scale_factor == 1.0``, otherwise interpolate it."""
        if scale_factor==1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
        return x
774
+
775
class FirstStagePostProcessor(nn.Module):
    """Encode inputs with a pretrained model, then refine and downsample the code.

    The pretrained model's encoding is normalised, projected to ``n_channels``
    with a 3x3 conv and passed through one ``ResnetBlock`` + ``Downsample``
    pair per entry of ``ch_mult``.

    :param ch_mult: channel multipliers (relative to ``n_channels``), one per stage.
    :param in_channels: channels of the pretrained model's encoding.
    :param pretrained_model: ready-made model; used when ``pretrained_config`` is None.
    :param reshape: if true, the output is rearranged to ``b (h w) c``.
    :param n_channels: base width; defaults to ``pretrained_model.encoder.ch``.
    :param dropout: dropout rate for the ResnetBlocks.
    :param pretrained_config: config from which the pretrained model is
        instantiated (and frozen); takes precedence over ``pretrained_model``.
    """

    def __init__(self, ch_mult:list, in_channels,
                 pretrained_model:nn.Module=None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is not None:
            self.instantiate_pretrained(pretrained_config)
        else:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model

        self.do_reshape = reshape

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3,
                              stride=1, padding=1)

        res_stages = []
        down_stages = []
        width_in = n_channels
        for mult in ch_mult:
            width_out = mult * n_channels
            res_stages.append(ResnetBlock(in_channels=width_in, out_channels=width_out, dropout=dropout))
            down_stages.append(Downsample(width_out, with_conv=False))
            width_in = width_out

        self.model = nn.ModuleList(res_stages)
        self.downsampler = nn.ModuleList(down_stages)

    def instantiate_pretrained(self, config):
        """Build the pretrained model from ``config``, put it in eval mode and freeze it."""
        self.pretrained_model = instantiate_from_config(config).eval()
        # self.pretrained_model.train = False
        for weight in self.pretrained_model.parameters():
            weight.requires_grad = False

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        """Encode ``x`` without gradients; a Gaussian posterior is reduced to its mode."""
        encoded = self.pretrained_model.encode(x)
        if isinstance(encoded, DiagonalGaussianDistribution):
            return encoded.mode()
        return encoded

    def forward(self, x):
        """Return the post-processed encoding of ``x``."""
        z_fs = self.encode_with_pretrained(x)
        z = nonlinearity(self.proj(self.proj_norm(z_fs)))

        for res_stage, down_stage in zip(self.model, self.downsampler):
            z = down_stage(res_stage(z, temb=None))

        if self.do_reshape:
            z = rearrange(z, 'b c h w -> b (h w) c')
        return z
840
+
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,963 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ldm.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from ldm.modules.attention import SpatialTransformer
21
+
22
+
23
+ # dummy replace
24
def convert_module_to_f16(x):
    """Placeholder for the float16 conversion hook; does nothing and returns None."""
    return None
26
+
27
def convert_module_to_f32(x):
    """Placeholder for the float32 conversion hook; does nothing and returns None."""
    return None
29
+
30
+
31
+ ## go
32
class AttentionPool2d(nn.Module):
    """
    Attention-based pooling over the flattened spatial positions.
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    :param spacial_dim: spatial side length; the positional embedding holds
                        ``spacial_dim ** 2 + 1`` positions.
    :param embed_dim: channel count of the input.
    :param num_heads_channels: channels per attention head.
    :param output_dim: output channel count; falls back to ``embed_dim``.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        n_positions = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, n_positions) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        """Return the attended feature at position 0 (the prepended mean token)."""
        batch, channels = x.shape[0], x.shape[1]
        tokens = x.reshape(batch, channels, -1)  # NC(HW)
        mean_token = tokens.mean(dim=-1, keepdim=True)
        tokens = th.cat([mean_token, tokens], dim=-1)  # NC(HW+1)
        tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)  # NC(HW+1)
        attended = self.attention(self.qkv_proj(tokens))
        return self.c_proj(attended)[:, :, 0]
60
+
61
+
62
class TimestepBlock(nn.Module):
    """
    Base class for modules whose forward() receives timestep embeddings as
    its second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x`, conditioned on the timestep embeddings `emb`.
        """
72
+
73
+
74
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that routes the extra inputs by child type:
    TimestepBlock children get `emb`, SpatialTransformer children get
    `context`, every other child gets only `x`.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                extra = (emb,)
            elif isinstance(layer, SpatialTransformer):
                extra = (context,)
            else:
                extra = ()
            x = layer(x, *extra)
        return x
89
+
90
+
91
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a 3-wide convolution.
    :param channels: channels expected in the input.
    :param use_conv: whether to apply the convolution after upsampling.
    :param dims: 1, 2 or 3 for 1D/2D/3D signals. With 3, only the two
                 innermost dimensions are doubled.
    :param out_channels: output channels of the convolution; defaults to `channels`.
    :param padding: padding passed to the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims != 3:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            # Keep the first spatial axis, double the remaining two.
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        return self.conv(x) if self.use_conv else x
120
+
121
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding, via a stride-2 transposed convolution.

    :param channels: input channels.
    :param out_channels: output channels; defaults to `channels`.
    :param ks: kernel size of the transposed convolution.
    """

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(
            self.channels,
            self.out_channels,
            kernel_size=ks,
            stride=2,
        )

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled
132
+
133
+
134
class Downsample(nn.Module):
    """
    Stride-2 downsampling, done either by a 3-wide convolution or by average pooling.
    :param channels: channels expected in the input.
    :param use_conv: if True use a strided convolution, otherwise average pooling
                     (which requires `out_channels == channels`).
    :param dims: 1, 2 or 3 for 1D/2D/3D signals. With 3, the stride is
                 (1, 2, 2), so only the two innermost dimensions shrink.
    :param out_channels: output channels; defaults to `channels`.
    :param padding: padding passed to the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        downsampled = self.op(x)
        return downsampled
161
+
162
+
163
class ResBlock(TimestepBlock):
    """
    Residual block conditioned on a timestep embedding; may change the channel
    count and may resample its input.
    :param channels: number of input channels.
    :param emb_channels: number of timestep-embedding channels.
    :param dropout: dropout rate.
    :param out_channels: number of output channels; defaults to `channels`.
    :param use_conv: when the channel count changes, use a 3-wide convolution
        on the skip path instead of a 1-wide one.
    :param use_scale_shift_norm: condition via scale/shift of the normalised
        features instead of adding the embedding.
    :param dims: 1, 2 or 3 for 1D/2D/3D signals.
    :param use_checkpoint: use gradient checkpointing for this module.
    :param up: upsample inside the block.
    :param down: downsample inside the block (ignored when `up` is set).
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # `up` wins over `down`; without either, both paths share one Identity.
        if up:
            resampler = Upsample
        elif down:
            resampler = Downsample
        else:
            resampler = None
        if resampler is None:
            self.h_upd = self.x_upd = nn.Identity()
        else:
            self.h_upd = resampler(channels, False, dims)
            self.x_upd = resampler(channels, False, dims)

        # Scale/shift conditioning needs two values per output channel.
        emb_width = 2 * self.out_channels if use_scale_shift_norm else self.out_channels
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(emb_channels, emb_width),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            skip_kernel = 3 if use_conv else 1
            if use_conv:
                self.skip_connection = conv_nd(
                    dims, channels, self.out_channels, skip_kernel, padding=1
                )
            else:
                self.skip_connection = conv_nd(dims, channels, self.out_channels, skip_kernel)

    def forward(self, x, emb):
        """
        Run the block on `x` conditioned on `emb`, through `checkpoint`.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        inputs = (x, emb)
        return checkpoint(self._forward, inputs, self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        if not self.updown:
            h = self.in_layers(x)
        else:
            # Resample between norm/activation and the convolution, and
            # resample the skip input the same way.
            pre_conv, conv = self.in_layers[:-1], self.in_layers[-1]
            h = self.h_upd(pre_conv(x))
            x = self.x_upd(x)
            h = conv(h)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append singleton axes until the embedding has as many axes as h.
        missing_axes = len(h.shape) - len(emb_out.shape)
        if missing_axes > 0:
            emb_out = emb_out[(...,) + (None,) * missing_axes]

        if not self.use_scale_shift_norm:
            h = self.out_layers(h + emb_out)
        else:
            norm, rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = rest(norm(h) * (1 + scale) + shift)
        return self.skip_connection(x) + h
276
+
277
+
278
class AttentionBlock(nn.Module):
    """
    Self-attention over all spatial positions of a feature map, with a
    residual connection.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: channels of the input and output.
    :param num_heads: number of heads, used when `num_head_channels` is -1.
    :param num_head_channels: channels per head; when not -1 it determines the
        number of heads and must divide `channels`.
    :param use_checkpoint: stored on the module (see note in `forward`).
    :param use_new_attention_order: pick QKVAttention instead of QKVAttentionLegacy.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # New order: split qkv before splitting heads. Legacy: heads first.
        attention_cls = QKVAttention if use_new_attention_order else QKVAttentionLegacy
        self.attention = attention_cls(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # NOTE: the checkpoint flag is hard-coded to True here; `self.use_checkpoint`
        # is not consulted.
        # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        batch, channels, *spatial = x.shape
        flat = x.reshape(batch, channels, -1)
        attended = self.attention(self.qkv(self.norm(flat)))
        residual = self.proj_out(attended)
        return (flat + residual).reshape(batch, channels, *spatial)
325
+
326
+
327
def count_flops_attn(model, _x, y):
    """
    `thop` counter hook for an attention operation: adds the matmul
    operation count to `model.total_ops`.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    batch, channels, *spatial_dims = y[0].shape
    positions = int(np.prod(spatial_dims))
    # Two matmuls of equal cost: one builds the attention weights, the other
    # mixes the value vectors with them.
    ops = 2 * batch * (positions ** 2) * channels
    model.total_ops += th.DoubleTensor([ops])
345
+
346
+
347
class QKVAttentionLegacy(nn.Module):
    """
    QKV attention that splits heads first and q/k/v second. Matches legacy
    QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        ch = width // (3 * heads)
        per_head = qkv.reshape(bs * heads, ch * 3, length)
        q, k, v = per_head.split(ch, dim=1)
        # Scaling q and k separately is more stable with f16 than dividing
        # the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = th.einsum("bts,bcs->bct", weight, v)
        return mixed.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
377
+
378
+
379
class QKVAttention(nn.Module):
    """
    QKV attention that splits q/k/v first and heads second.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        ch = width // (3 * heads)
        head_shape = (bs * heads, ch, length)
        q, k, v = qkv.chunk(3, dim=1)
        # Scaling q and k separately is more stable with f16 than dividing
        # the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        q_heads = (q * scale).view(*head_shape)
        k_heads = (k * scale).view(*head_shape)
        logits = th.einsum("bct,bcs->bts", q_heads, k_heads)
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = th.einsum("bts,bcs->bct", weight, v.reshape(*head_shape))
        return mixed.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
411
+
412
+
413
+ class UNetModel(nn.Module):
414
+ """
415
+ The full UNet model with attention and timestep embedding.
416
+ :param in_channels: channels in the input Tensor.
417
+ :param model_channels: base channel count for the model.
418
+ :param out_channels: channels in the output Tensor.
419
+ :param num_res_blocks: number of residual blocks per downsample.
420
+ :param attention_resolutions: a collection of downsample rates at which
421
+ attention will take place. May be a set, list, or tuple.
422
+ For example, if this contains 4, then at 4x downsampling, attention
423
+ will be used.
424
+ :param dropout: the dropout probability.
425
+ :param channel_mult: channel multiplier for each level of the UNet.
426
+ :param conv_resample: if True, use learned convolutions for upsampling and
427
+ downsampling.
428
+ :param dims: determines if the signal is 1D, 2D, or 3D.
429
+ :param num_classes: if specified (as an int), then this model will be
430
+ class-conditional with `num_classes` classes.
431
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
432
+ :param num_heads: the number of attention heads in each attention layer.
433
+ :param num_heads_channels: if specified, ignore num_heads and instead use
434
+ a fixed channel width per attention head.
435
+ :param num_heads_upsample: works with num_heads to set a different number
436
+ of heads for upsampling. Deprecated.
437
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
438
+ :param resblock_updown: use residual blocks for up/downsampling.
439
+ :param use_new_attention_order: use a different attention pattern for potentially
440
+ increased efficiency.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ image_size,
446
+ in_channels,
447
+ model_channels,
448
+ out_channels,
449
+ num_res_blocks,
450
+ attention_resolutions,
451
+ dropout=0,
452
+ channel_mult=(1, 2, 4, 8),
453
+ conv_resample=True,
454
+ dims=2,
455
+ num_classes=None,
456
+ use_checkpoint=False,
457
+ use_fp16=False,
458
+ num_heads=-1,
459
+ num_head_channels=-1,
460
+ num_heads_upsample=-1,
461
+ use_scale_shift_norm=False,
462
+ resblock_updown=False,
463
+ use_new_attention_order=False,
464
+ use_spatial_transformer=False, # custom transformer support
465
+ transformer_depth=1, # custom transformer support
466
+ context_dim=None, # custom transformer support
467
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
468
+ legacy=True,
469
+ ):
470
+ super().__init__()
471
+ if use_spatial_transformer:
472
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
473
+
474
+ if context_dim is not None:
475
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
476
+ from omegaconf.listconfig import ListConfig
477
+ if type(context_dim) == ListConfig:
478
+ context_dim = list(context_dim)
479
+
480
+ if num_heads_upsample == -1:
481
+ num_heads_upsample = num_heads
482
+
483
+ if num_heads == -1:
484
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
485
+
486
+ if num_head_channels == -1:
487
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
488
+
489
+ self.image_size = image_size
490
+ self.in_channels = in_channels
491
+ self.model_channels = model_channels
492
+ self.out_channels = out_channels
493
+ self.num_res_blocks = num_res_blocks
494
+ self.attention_resolutions = attention_resolutions
495
+ self.dropout = dropout
496
+ self.channel_mult = channel_mult
497
+ self.conv_resample = conv_resample
498
+ self.num_classes = num_classes
499
+ self.use_checkpoint = use_checkpoint
500
+ self.dtype = th.float16 if use_fp16 else th.float32
501
+ self.num_heads = num_heads
502
+ self.num_head_channels = num_head_channels
503
+ self.num_heads_upsample = num_heads_upsample
504
+ self.predict_codebook_ids = n_embed is not None
505
+
506
+ time_embed_dim = model_channels * 4
507
+ self.time_embed = nn.Sequential(
508
+ linear(model_channels, time_embed_dim),
509
+ nn.SiLU(),
510
+ linear(time_embed_dim, time_embed_dim),
511
+ )
512
+
513
+ if self.num_classes is not None:
514
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
515
+
516
+ self.input_blocks = nn.ModuleList(
517
+ [
518
+ TimestepEmbedSequential(
519
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
520
+ )
521
+ ]
522
+ )
523
+ self._feature_size = model_channels
524
+ input_block_chans = [model_channels]
525
+ ch = model_channels
526
+ ds = 1
527
+ for level, mult in enumerate(channel_mult):
528
+ for _ in range(num_res_blocks):
529
+ layers = [
530
+ ResBlock(
531
+ ch,
532
+ time_embed_dim,
533
+ dropout,
534
+ out_channels=mult * model_channels,
535
+ dims=dims,
536
+ use_checkpoint=use_checkpoint,
537
+ use_scale_shift_norm=use_scale_shift_norm,
538
+ )
539
+ ]
540
+ ch = mult * model_channels
541
+ if ds in attention_resolutions:
542
+ if num_head_channels == -1:
543
+ dim_head = ch // num_heads
544
+ else:
545
+ num_heads = ch // num_head_channels
546
+ dim_head = num_head_channels
547
+ if legacy:
548
+ #num_heads = 1
549
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
550
+ layers.append(
551
+ AttentionBlock(
552
+ ch,
553
+ use_checkpoint=use_checkpoint,
554
+ num_heads=num_heads,
555
+ num_head_channels=dim_head,
556
+ use_new_attention_order=use_new_attention_order,
557
+ ) if not use_spatial_transformer else SpatialTransformer(
558
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
559
+ )
560
+ )
561
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
562
+ self._feature_size += ch
563
+ input_block_chans.append(ch)
564
+ if level != len(channel_mult) - 1:
565
+ out_ch = ch
566
+ self.input_blocks.append(
567
+ TimestepEmbedSequential(
568
+ ResBlock(
569
+ ch,
570
+ time_embed_dim,
571
+ dropout,
572
+ out_channels=out_ch,
573
+ dims=dims,
574
+ use_checkpoint=use_checkpoint,
575
+ use_scale_shift_norm=use_scale_shift_norm,
576
+ down=True,
577
+ )
578
+ if resblock_updown
579
+ else Downsample(
580
+ ch, conv_resample, dims=dims, out_channels=out_ch
581
+ )
582
+ )
583
+ )
584
+ ch = out_ch
585
+ input_block_chans.append(ch)
586
+ ds *= 2
587
+ self._feature_size += ch
588
+
589
+ if num_head_channels == -1:
590
+ dim_head = ch // num_heads
591
+ else:
592
+ num_heads = ch // num_head_channels
593
+ dim_head = num_head_channels
594
+ if legacy:
595
+ #num_heads = 1
596
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
597
+ self.middle_block = TimestepEmbedSequential(
598
+ ResBlock(
599
+ ch,
600
+ time_embed_dim,
601
+ dropout,
602
+ dims=dims,
603
+ use_checkpoint=use_checkpoint,
604
+ use_scale_shift_norm=use_scale_shift_norm,
605
+ ),
606
+ AttentionBlock(
607
+ ch,
608
+ use_checkpoint=use_checkpoint,
609
+ num_heads=num_heads,
610
+ num_head_channels=dim_head,
611
+ use_new_attention_order=use_new_attention_order,
612
+ ) if not use_spatial_transformer else SpatialTransformer(
613
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
614
+ ),
615
+ ResBlock(
616
+ ch,
617
+ time_embed_dim,
618
+ dropout,
619
+ dims=dims,
620
+ use_checkpoint=use_checkpoint,
621
+ use_scale_shift_norm=use_scale_shift_norm,
622
+ ),
623
+ )
624
+ self._feature_size += ch
625
+
626
+ self.output_blocks = nn.ModuleList([])
627
+ for level, mult in list(enumerate(channel_mult))[::-1]:
628
+ for i in range(num_res_blocks + 1):
629
+ ich = input_block_chans.pop()
630
+ layers = [
631
+ ResBlock(
632
+ ch + ich,
633
+ time_embed_dim,
634
+ dropout,
635
+ out_channels=model_channels * mult,
636
+ dims=dims,
637
+ use_checkpoint=use_checkpoint,
638
+ use_scale_shift_norm=use_scale_shift_norm,
639
+ )
640
+ ]
641
+ ch = model_channels * mult
642
+ if ds in attention_resolutions:
643
+ if num_head_channels == -1:
644
+ dim_head = ch // num_heads
645
+ else:
646
+ num_heads = ch // num_head_channels
647
+ dim_head = num_head_channels
648
+ if legacy:
649
+ #num_heads = 1
650
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
651
+ layers.append(
652
+ AttentionBlock(
653
+ ch,
654
+ use_checkpoint=use_checkpoint,
655
+ num_heads=num_heads_upsample,
656
+ num_head_channels=dim_head,
657
+ use_new_attention_order=use_new_attention_order,
658
+ ) if not use_spatial_transformer else SpatialTransformer(
659
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
660
+ )
661
+ )
662
+ if level and i == num_res_blocks:
663
+ out_ch = ch
664
+ layers.append(
665
+ ResBlock(
666
+ ch,
667
+ time_embed_dim,
668
+ dropout,
669
+ out_channels=out_ch,
670
+ dims=dims,
671
+ use_checkpoint=use_checkpoint,
672
+ use_scale_shift_norm=use_scale_shift_norm,
673
+ up=True,
674
+ )
675
+ if resblock_updown
676
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
677
+ )
678
+ ds //= 2
679
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
680
+ self._feature_size += ch
681
+
682
+ self.out = nn.Sequential(
683
+ normalization(ch),
684
+ nn.SiLU(),
685
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
686
+ )
687
+ if self.predict_codebook_ids:
688
+ self.id_predictor = nn.Sequential(
689
+ normalization(ch),
690
+ conv_nd(dims, model_channels, n_embed, 1),
691
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
692
+ )
693
+
694
def convert_to_fp16(self):
    """
    Cast the torso of the model (input, middle and output blocks) to float16.
    """
    for torso_part in (self.input_blocks, self.middle_block, self.output_blocks):
        torso_part.apply(convert_module_to_f16)
701
+
702
def convert_to_fp32(self):
    """
    Cast the torso of the model (input, middle and output blocks) to float32.
    """
    for torso_part in (self.input_blocks, self.middle_block, self.output_blocks):
        torso_part.apply(convert_module_to_f32)
709
+
710
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
    """
    Run the UNet on a batch of inputs.
    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param context: conditioning plugged in via crossattn
    :param y: an [N] Tensor of labels, if class-conditional.
    :param kwargs: accepted for call-site compatibility; not used here.
    :return: an [N x C x ...] Tensor of outputs.
    """
    is_class_conditional = self.num_classes is not None
    assert (y is not None) == (
        is_class_conditional
    ), "must specify y if and only if the model is class-conditional"

    # Timestep embedding, optionally shifted by the class-label embedding.
    emb = self.time_embed(
        timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    )
    if is_class_conditional:
        assert y.shape == (x.shape[0],)
        emb = emb + self.label_emb(y)

    # Down path: remember every activation for the skip connections.
    skips = []
    h = x.type(self.dtype)
    for block in self.input_blocks:
        h = block(h, emb, context)
        skips.append(h)

    h = self.middle_block(h, emb, context)

    # Up path: consume the skips in reverse order, concatenated on channels.
    for block in self.output_blocks:
        h = block(th.cat([h, skips.pop()], dim=1), emb, context)
    h = h.type(x.dtype)

    # `id_predictor` only exists when codebook-id prediction is enabled.
    head = self.id_predictor if self.predict_codebook_ids else self.out
    return head(h)
745
+
746
+
747
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.

    Only the downsampling path and the middle block are built; the result is
    reduced to an [N x K] output by the head selected through ``pool``:

    * ``"adaptive"``   -- norm, SiLU, ``AdaptiveAvgPool2d((1, 1))``, a
      zero-initialised 1x1 conv to ``out_channels`` and a flatten.
    * ``"attention"``  -- norm, SiLU and an ``AttentionPool2d``; requires
      ``num_head_channels != -1``.
    * ``"spatial"``    -- two linear layers (with ReLU) applied to the
      concatenated spatial means of every block output.
    * ``"spatial_v2"`` -- like ``"spatial"`` but with normalization + SiLU
      between the two linear layers.

    Any other ``pool`` value raises ``NotImplementedError``.
    Extra positional/keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs
    ):
        super().__init__()

        # -1 means "same as num_heads". The value is stored below but is not
        # used by any layer built in this class.
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        # Dtype the input is cast to at the start of forward().
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        # MLP that lifts the sinusoidal timestep embedding to 4x model width.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # First input block: a plain 3x3 conv from in_channels to model_channels.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # Running sum of block output channels; sizes the "spatial" heads.
        self._feature_size = model_channels
        # Appended to below but never read within this class.
        input_block_chans = [model_channels]
        ch = model_channels  # current channel count
        ds = 1  # current downsampling factor (doubled after each downsample)
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                # Attention is added only at the configured downsampling factors.
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels, but not after the last one.
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # Middle block: ResBlock -> attention -> ResBlock at the final width.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            # NOTE(review): AdaptiveAvgPool2d is hard-coded while the conv
            # honours `dims` -- confirm dims != 2 is not expected here.
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            # Consumes the concatenated per-block spatial means (see forward()).
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            # "spatial*" heads collect the spatial mean (dims 2 and 3) of
            # every block output, cast back to the input dtype.
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            # Concatenate all collected means along the last axis.
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
963
+
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a beta (noise variance) schedule.
    :param schedule: one of "linear", "cosine", "sqrt_linear" or "sqrt".
                     "linear" interpolates linearly in sqrt(beta) and squares,
                     "sqrt_linear" interpolates linearly in beta itself,
                     "sqrt" takes the square root of a linear interpolation,
                     "cosine" derives betas from a squared-cosine alpha curve.
    :param n_timestep: number of betas to produce.
    :param linear_start: first value for the interpolated schedules.
    :param linear_end: last value for the interpolated schedules.
    :param cosine_s: offset used by the cosine schedule.
    :return: a float64 numpy array with n_timestep entries.
    :raises ValueError: if `schedule` is not a known name.
    """
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch, not np.clip: np.clip on a Tensor may hand back an
        # ndarray (depending on the numpy/torch versions), which would make
        # the `.numpy()` call below fail.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
44
+
45
+
46
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Pick the subset of DDPM timesteps visited by the DDIM sampler.
    :param ddim_discr_method: 'uniform' (constant stride) or 'quad'
                              (quadratically spaced).
    :param num_ddim_timesteps: requested number of DDIM steps.
    :param num_ddpm_timesteps: number of steps of the underlying DDPM.
    :param verbose: if True, print the selected timesteps.
    :return: integer numpy array of selected timesteps, each shifted by +1.
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method not in ('uniform', 'quad'):
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray([*range(0, num_ddpm_timesteps, stride)])
    else:
        upper = np.sqrt(num_ddpm_timesteps * .8)
        ddim_timesteps = (np.linspace(0, upper, num_ddim_timesteps) ** 2).astype(int)

    # Shift by one so the final alpha values come out right (the ones from
    # the first scale to data during sampling).
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
61
+
62
+
63
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the DDIM variance schedule for the selected timesteps.
    :param alphacums: cumulative alpha products, indexable by `ddim_timesteps`.
    :param ddim_timesteps: integer array of selected timesteps.
    :param eta: scale of the stochastic term (0 gives sigmas of zero).
    :param verbose: if True, print the selected alphas and sigmas.
    :return: tuple (sigmas, alphas, alphas_prev).
    """
    # Alphas at the selected steps, and the alpha preceding each of them
    # (the very first entry of `alphacums` precedes the first selected step).
    alphas = alphacums[ddim_timesteps]
    preceding = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray([alphacums[0]] + preceding)

    # Sigma formula from https://arxiv.org/abs/2010.02502
    variance_ratio = (1 - alphas_prev) / (1 - alphas)
    sigmas = eta * np.sqrt(variance_ratio * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
75
+
76
+
77
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) curve, t in [0, 1], into betas.
    alpha_bar defines the cumulative product of (1-beta) over time.
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t from 0 to 1 and returning the
                      cumulative product of (1-beta) up to that point.
    :param max_beta: upper cap applied to every beta; use values lower
                     than 1 to prevent singularities.
    :return: a numpy array with num_diffusion_timesteps betas.
    """
    def _beta_at(step):
        # Interval [t_start, t_end] covered by this discrete step.
        t_start = step / num_diffusion_timesteps
        t_end = (step + 1) / num_diffusion_timesteps
        return min(1 - alpha_bar(t_end) / alpha_bar(t_start), max_beta)

    return np.array([_beta_at(step) for step in range(num_diffusion_timesteps)])
94
+
95
+
96
def extract_into_tensor(a, t, x_shape):
    """
    Gather entries of `a` at indices `t` and reshape for broadcasting.
    :param a: Tensor to gather from (along its last dimension).
    :param t: Tensor of indices; its first dimension is the batch size.
    :param x_shape: shape of the tensor the result will be broadcast against.
    :return: Tensor of shape [batch, 1, ..., 1] with len(x_shape) dimensions.
    """
    batch, *_ = t.shape
    gathered = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing_ones)
100
+
101
+
102
def checkpoint(func, inputs, params, flag):
    """
    Call `func`, optionally with gradient checkpointing: intermediate
    activations are not cached, trading extra compute in the backward pass
    for reduced memory.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing and call `func`
                 directly.
    """
    if not flag:
        return func(*inputs)
    # Inputs first, then parameters; the input count tells the autograd
    # function where to split them again.
    packed = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *packed)
117
+
118
+
119
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function implementing gradient checkpointing: the forward pass
    runs `run_function` under `torch.no_grad()`, and the backward pass
    re-runs it with gradients enabled to obtain the input gradients.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # `args` holds `length` input tensors followed by the parameters.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs so the recomputation builds a fresh graph
        # whose leaves are these tensors.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        # allow_unused: inputs/params that do not affect the output get None.
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The two leading Nones are the gradients for `run_function` and
        # `length`, which are not differentiable.
        return (None, None) + input_grads
149
+
150
+
151
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of timesteps.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and just repeat each
                        timestep `dim` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    scaled = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(scaled / half).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]

    # Cosines first, then sines; an odd `dim` is padded with one zero column.
    parts = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        parts.append(torch.zeros_like(angles[:, :1]))
    return torch.cat(parts, dim=-1)
172
+
173
+
174
def zero_module(module):
    """
    Set every parameter of `module` to zero in place and hand the module back.
    """
    for param in module.parameters():
        # detach() shares storage, so the in-place zeroing hits the parameter
        # itself without being tracked by autograd.
        param.detach().zero_()
    return module
181
+
182
+
183
def scale_module(module, scale):
    """
    Multiply every parameter of `module` by `scale` in place and hand the
    module back.
    """
    for param in module.parameters():
        # detach() shares storage, so the in-place scaling hits the parameter
        # itself without being tracked by autograd.
        param.detach().mul_(scale)
    return module
190
+
191
+
192
def mean_flat(tensor):
    """
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
197
+
198
+
199
def normalization(channels):
    """
    Build the standard normalization layer: a GroupNorm32 with 32 groups.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_groups=32, num_channels=channels)
206
+
207
+
208
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209
class SiLU(nn.Module):
    """SiLU activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
212
+
213
+
214
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and returns the input's dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
217
+
218
def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer for 1, 2 or 3 spatial dimensions.
    Remaining arguments are forwarded to the chosen nn.ConvNd constructor.
    Raises ValueError for any other `dims`.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
229
+
230
+
231
def linear(*args, **kwargs):
    """
    Build an nn.Linear layer; all arguments are forwarded unchanged.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
236
+
237
+
238
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average pooling layer for 1, 2 or 3 spatial dimensions.
    Remaining arguments are forwarded to the chosen nn.AvgPoolNd constructor.
    Raises ValueError for any other `dims`.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
249
+
250
+
251
class HybridConditioner(nn.Module):
    """
    Holds two conditioners built from configs: one for concatenation
    conditioning and one for cross-attention conditioning.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """
        Run each input through its conditioner and return both results,
        each wrapped in a single-element list, keyed by conditioning type.
        """
        return {
            'c_concat': [self.concat_conditioner(c_concat)],
            'c_crossattn': [self.crossattn_conditioner(c_crossattn)],
        }
262
+
263
+
264
def noise_like(shape, device, repeat=False):
    """
    Draw standard-normal noise of the given shape on `device`.
    With `repeat=True` a single sample is drawn and tiled along the first
    (batch) dimension, so every batch element gets the same noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
ldm/modules/discriminator/__pycache__/model.cpython-38.pyc ADDED
Binary file (2.32 kB). View file
 
ldm/modules/discriminator/model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #based on https://github.com/CompVis/taming-transformers
2
+
3
+ import functools
4
+ import torch.nn as nn
5
+
6
+
7
+ from ldm.modules.util import ActNorm
8
+
9
+
10
def weights_init(m):
    """
    Initialise a layer in place, chosen by its class name:
    names containing 'Conv' get weights ~ N(0, 0.02); names containing
    'BatchNorm' get weights ~ N(1, 0.02) and zero bias. Others are untouched.
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
17
+
18
+
19
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)     -- the number of channels in input images
            ndf (int)          -- the number of filters in the first conv layer;
                                  later layers use ndf * min(2 ** n, 8)
            n_layers (int)     -- the number of conv layers in the discriminator
            use_actnorm (bool) -- use ActNorm instead of BatchNorm2d as the
                                  normalization layer
        """
        super().__init__()
        norm_layer = ActNorm if use_actnorm else nn.BatchNorm2d
        # BatchNorm2d has its own affine parameters, so a bias on the conv in
        # front of it is redundant. (norm_layer is always a class here, so the
        # former functools.partial check could never trigger.)
        use_bias = norm_layer is not nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more widening conv, this time without striding.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (158 Bytes). View file
 
ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc ADDED
Binary file (3.8 kB). View file
 
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
class AbstractDistribution:
    """Interface for distributions exposing `sample` and `mode`."""

    def sample(self):
        """Draw a sample; must be provided by subclasses."""
        raise NotImplementedError()

    def mode(self):
        """Return the mode; must be provided by subclasses."""
        raise NotImplementedError()
11
+
12
+
13
class DiracDistribution(AbstractDistribution):
    """Distribution concentrated on one fixed value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        """Return the stored value (sampling is deterministic)."""
        return self.value

    def mode(self):
        """Return the stored value."""
        return self.value
22
+
23
+
24
class DiagonalGaussianDistribution(object):
    """
    Gaussian with diagonal covariance, parameterised by a tensor whose
    dimension 1 holds the means followed by the log-variances.
    """

    def __init__(self, parameters, deterministic=False):
        """
        :param parameters: Tensor that is split in two equal chunks along
                           dim 1 into (mean, logvar).
        :param deterministic: if True, variance and std are forced to zero.
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Keep exp() of the log-variance in a numerically safe range.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device of mean (a view of parameters).
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw one sample: mean + std * standard-normal noise."""
        # Noise is drawn with the default generator and then moved, which
        # keeps seeded results identical to earlier versions of this code.
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """
        KL divergence to `other`, or to a standard normal if `other` is None,
        summed over dims 1, 2 and 3.
        In deterministic mode a single zero is returned, placed on the
        parameters' device so it can be combined with other loss terms.
        """
        if self.deterministic:
            return torch.tensor([0.0], device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of `sample`, summed over `dims`.
        In deterministic mode a single zero is returned on the parameters'
        device.
        """
        if self.deterministic:
            return torch.tensor([0.0], device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
63
+
64
+
65
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    KL divergence between two gaussians given by (mean, log-variance).
    Arguments broadcast against each other, so batches can be compared to
    scalars; at least one argument has to be a Tensor.
    """
    # The first Tensor argument serves as dtype/device reference.
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # torch.exp() needs Tensors; broadcasting alone would not convert scalars.
        return value if isinstance(value, torch.Tensor) else torch.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    squared_diff = (mean1 - mean2) ** 2
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + squared_diff * torch.exp(-logvar2)
    )
ldm/modules/ema.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class LitEma(nn.Module):
    """
    Exponential moving average of a model's trainable parameters, kept as
    buffers of this module.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        """
        :param model: module whose parameters with requires_grad are tracked.
        :param decay: EMA decay; must lie in [0, 1].
        :param use_num_upates: if True, count updates (starting at 0) and use
                               the count to warm up the decay; otherwise the
                               counter is fixed at -1 and `decay` is used as is.
        """
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        initial_count = 0 if use_num_upates else -1
        self.register_buffer('num_updates', torch.tensor(initial_count, dtype=torch.int))

        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            # Buffer names must not contain '.', so strip it from the
            # parameter name and remember the mapping.
            s_name = name.replace('.', '')
            self.m_name2s_name[name] = s_name
            self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        """Fold the current parameters of `model` into the shadow buffers."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: early updates use a smaller effective decay.
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            shadow_params = dict(self.named_buffers())

            for key, param in model.named_parameters():
                if not param.requires_grad:
                    assert not key in self.m_name2s_name
                    continue
                sname = self.m_name2s_name[key]
                shadow = shadow_params[sname].type_as(param)
                shadow_params[sname] = shadow
                shadow.sub_(one_minus_decay * (shadow - param))

    def copy_to(self, model):
        """Overwrite the trainable parameters of `model` with the shadows."""
        shadow_params = dict(self.named_buffers())
        for key, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Keep a copy of the given parameters so they can be restored later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [p.clone() for p in parameters]

    def restore(self, parameters):
        """
        Write back the parameters saved by `store`.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process: call `store` before `copy_to`, and
        after validation (or model saving) call this to get the former
        parameters back.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for saved, target in zip(self.collected_params, parameters):
            target.data.copy_(saved.data)
ldm/modules/encoders/__init__.py ADDED
File without changes
ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (153 Bytes). View file
 
ldm/modules/encoders/__pycache__/modules.cpython-38.pyc ADDED
Binary file (14 kB). View file
 
ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+ import clip
5
+ import open_clip
6
+ from einops import rearrange, repeat
7
+ from transformers import CLIPTokenizer, CLIPTextModel
8
+ # import kornia
9
+
10
+ from transformers import BertTokenizerFast # TODO: add to reuquirements
11
+ import os
12
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
+ from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
14
+
15
+
16
class AbstractEncoder(nn.Module):
    """Common base class for encoders; subclasses must override ``encode``."""

    def encode(self, *args, **kwargs):
        """Encode the inputs. Not implemented on the base class."""
        raise NotImplementedError
22
+
23
+
24
+
25
class ClassEmbedder(nn.Module):
    """Look up a learned embedding vector for each class label in a batch."""

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        """
        Args:
            embed_dim: Size of each embedding vector.
            n_classes: Number of rows in the embedding table.
            key: Default key under which labels are read from the batch.
        """
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        """Embed ``batch[key]`` (``self.key`` when ``key`` is None)."""
        lookup_key = self.key if key is None else key
        # A new axis is inserted after the first one so the result can be
        # consumed by cross-attention.
        labels = batch[lookup_key][:, None]
        return self.embedding(labels)
38
+
39
class HeirClassEmbedder(nn.Module):
    """Embed hierarchical class labels with one embedding table per level.

    ``batch[key][0]`` is iterated sample by sample; each entry is a string of
    integers separated by ``', '`` (one integer per hierarchy level). Every
    level is embedded into ``embed_dim // len(n_classes)`` channels and the
    per-level embeddings are concatenated along the last axis.
    """

    def __init__(self, embed_dim, n_classes=[3, 6, 9, 38], key='class', device='cuda'):
        """
        Args:
            embed_dim: Total embedding width; must be divisible by the number
                of levels.
            n_classes: Number of classes at each hierarchy level.
            key: Default key under which labels are read from the batch.
            device: Stored as ``self.device``. Index tensors follow the
                device of the embedding weights, not this value.
        """
        super().__init__()
        n_classes = list(n_classes)  # never alias the shared default list
        assert embed_dim % len(n_classes) == 0
        self.key = key
        self.device = device
        self.n_levels = len(n_classes)
        self.embed_heir_dim = embed_dim // len(n_classes)
        self.embedding_layers = []  # unused; kept so existing attribute access keeps working
        # The attribute names embedding_level0, embedding_level1, ... are part
        # of the state_dict layout and therefore must stay exactly like this.
        for level, level_classes in enumerate(n_classes):
            setattr(self, f'embedding_level{level}',
                    nn.Embedding(level_classes, self.embed_heir_dim))

    def forward(self, batch, key=None):
        """Return the concatenated per-level embeddings for the batch.

        The result has one row per sample, a singleton second axis (for use
        in cross-attention) and ``n_levels * embed_heir_dim`` channels.
        """
        if key is None:
            key = self.key
        rows = [[int(num) for num in item.split(', ')] for item in batch[key][0]]
        # Put the indices where the weights actually live, so the module also
        # works on CPU or after being moved with .to().
        index_device = self.embedding_level0.weight.device
        labels = torch.tensor(rows, dtype=torch.long, device=index_device)
        per_level = [
            getattr(self, f'embedding_level{level}')(labels[:, level:level + 1])
            for level in range(self.n_levels)
        ]
        return torch.cat(per_level, dim=-1)
80
+
81
+
82
class HeirClassEmbedderMultiLevel(nn.Module):
    """Embed hierarchical class labels for an arbitrary number of levels.

    ``batch[key][0]`` is iterated sample by sample; each entry is a string of
    integers separated by ``', '`` (one integer per hierarchy level). Level
    ``i`` is embedded by ``self.embedding_layers[i]`` and the results are
    concatenated along the last axis.
    """

    def __init__(self, embed_dim, n_classes=[3, 6, 9, 38], key='class', device='cuda'):
        """
        Args:
            embed_dim: Total embedding width; must be divisible by the number
                of levels.
            n_classes: Number of classes at each hierarchy level.
            key: Default key under which labels are read from the batch.
            device: Device the per-level embedding tables are moved to at
                construction time; also stored as ``self.device``.
        """
        super().__init__()
        n_classes = list(n_classes)  # never alias the shared default list
        assert embed_dim % len(n_classes) == 0
        self.key = key
        self.device = device
        self.n_classes = n_classes
        self.embed_heir_dim = embed_dim // len(n_classes)
        # Legacy tables embedding_level0..3: forward() does not use them, but
        # they are part of the state_dict layout. They are created before the
        # ModuleList so the order of random initialisation is unchanged.
        # Slicing lets hierarchies with fewer than four levels be constructed.
        for level, level_classes in enumerate(n_classes[:4]):
            setattr(self, f'embedding_level{level}',
                    nn.Embedding(level_classes, self.embed_heir_dim))
        self.embedding_layers = nn.ModuleList([
            nn.Embedding(level_classes, self.embed_heir_dim).to(self.device)
            for level_classes in n_classes
        ])

    def forward(self, batch, key=None):
        """Return the concatenated per-level embeddings for the batch.

        The result has one row per sample, a singleton second axis (for use
        in cross-attention) and ``len(n_classes) * embed_heir_dim`` channels.
        """
        if key is None:
            key = self.key
        rows = [[int(num) for num in item.split(', ')] for item in batch[key][0]]
        # Follow the weights' real device: self.device goes stale as soon as
        # the module is moved with .to()/.cpu()/.cuda().
        index_device = self.embedding_layers[0].weight.device
        labels = torch.tensor(rows, dtype=torch.long, device=index_device)
        per_level = [
            layer(labels[:, level:level + 1])
            for level, layer in enumerate(self.embedding_layers)
        ]
        return torch.cat(per_level, dim=-1)
139
+
140
class TransformerEmbedder(AbstractEncoder):
    """Token encoder built from a stack of transformer encoder layers."""

    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        """
        Args:
            n_embed: Passed to ``Encoder`` as ``dim``.
            n_layer: Passed to ``Encoder`` as ``depth``.
            vocab_size: Passed to ``TransformerWrapper`` as ``num_tokens``.
            max_seq_len: Passed to ``TransformerWrapper`` as ``max_seq_len``.
            device: Device the input tokens are moved to in ``forward``.
        """
        super().__init__()
        self.device = device
        attention_stack = Encoder(dim=n_embed, depth=n_layer)
        self.transformer = TransformerWrapper(num_tokens=vocab_size,
                                              max_seq_len=max_seq_len,
                                              attn_layers=attention_stack)

    def forward(self, tokens):
        """Move ``tokens`` to ``self.device`` and return the transformer embeddings."""
        on_device = tokens.to(self.device)
        return self.transformer(on_device, return_embeddings=True)

    def encode(self, x):
        """Alias for calling the module."""
        return self(x)
155
+
156
+
157
class BERTTokenizer(AbstractEncoder):
    """Wraps the pretrained "bert-base-uncased" fast tokenizer from Hugging Face."""

    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        """
        Args:
            device: Device the token ids are moved to.
            vq_interface: When true, ``encode`` returns the nested
                ``(None, None, [None, None, tokens])`` structure instead of
                the bare tokens.
            max_length: Length every sequence is truncated/padded to.
        """
        super().__init__()
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    def forward(self, text):
        """Tokenize ``text`` and return the ``input_ids`` on ``self.device``."""
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded["input_ids"].to(self.device)

    @torch.no_grad()
    def encode(self, text):
        """Tokenize ``text``; the return layout depends on ``self.vq_interface``."""
        tokens = self(text)
        if self.vq_interface:
            return None, None, [None, None, tokens]
        return tokens

    def decode(self, text):
        """Return ``text`` unchanged."""
        return text
181
+
182
+
183
class BERTEmbedderExtra(AbstractEncoder):
    """BERT tokenizer (extended with ``<N>`` and ``<E>`` special tokens)
    followed by some transformer encoder layers."""

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        """
        Args:
            n_embed: Passed to ``Encoder`` as ``dim``.
            n_layer: Passed to ``Encoder`` as ``depth``.
            vocab_size: Vocabulary size of the transformer. Only used when
                ``use_tokenizer`` is false; otherwise it is replaced by the
                size of the extended tokenizer vocabulary.
            max_seq_len: Maximum sequence length of tokenizer and transformer.
            device: Stored as ``self.device`` and handed to the tokenizer.
            use_tokenizer: When false, ``forward`` expects token ids.
            embedding_dropout: Passed to ``TransformerWrapper`` as ``emb_dropout``.
        """
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        self.device = device
        if self.use_tknz_fn:
            # Hand the device through so the tokens land where the caller asked.
            self.tknz_fn = BERTTokenizer(device=device, vq_interface=False,
                                         max_length=max_seq_len)
            # The tokenizer only exists on this branch, so the vocabulary is
            # extended here and not unconditionally.
            special_tokens_dict = {'additional_special_tokens': ['<N>', '<E>']}
            self.tknz_fn.tokenizer.add_special_tokens(special_tokens_dict)
            vocab_size = len(self.tknz_fn.tokenizer)

        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        """Return transformer embeddings for ``text`` (raw text, or token ids
        when the tokenizer is disabled)."""
        tokens = self.tknz_fn(text) if self.use_tknz_fn else text
        return self.transformer(tokens, return_embeddings=True)

    def encode(self, text):
        """Alias for calling the module."""
        # output of length 77
        return self(text)
212
+
213
+
214
class BERTEmbedder(AbstractEncoder):
    """BERT tokenizer followed by some transformer encoder layers."""

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        """
        Args:
            n_embed: Passed to ``Encoder`` as ``dim``.
            n_layer: Passed to ``Encoder`` as ``depth``.
            vocab_size: Passed to ``TransformerWrapper`` as ``num_tokens``.
            max_seq_len: Maximum sequence length of tokenizer and transformer.
            device: Stored as ``self.device`` and handed to the tokenizer.
            use_tokenizer: When false, ``forward`` expects token ids.
            embedding_dropout: Passed to ``TransformerWrapper`` as ``emb_dropout``.
        """
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            # Hand the device through; otherwise the tokenizer falls back to
            # its own default ("cuda") whatever the caller requested.
            self.tknz_fn = BERTTokenizer(device=device, vq_interface=False,
                                         max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        """Return transformer embeddings for ``text`` (raw text, or token ids
        when the tokenizer is disabled)."""
        tokens = self.tknz_fn(text) if self.use_tknz_fn else text
        return self.transformer(tokens, return_embeddings=True)

    def encode(self, text):
        """Alias for calling the module."""
        # output of length 77
        return self(text)
238
+
239
+
240
class SpatialRescaler(nn.Module):
    """Repeatedly rescale an input and optionally remap its channels with a 1x1 convolution."""

    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        """
        Args:
            n_stages: How many times the interpolation is applied (>= 0).
            method: Interpolation mode handed to ``torch.nn.functional.interpolate``.
            multiplier: ``scale_factor`` used at every stage.
            in_channels: Input channels of the optional 1x1 convolution.
            out_channels: Output channels of the 1x1 convolution; ``None``
                disables the remapping.
            bias: Whether the 1x1 convolution has a bias.
        """
        super().__init__()
        supported_methods = ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in supported_methods
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if not self.remap_output:
            return
        print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
        self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        """Apply ``n_stages`` interpolations, then the channel remapping if enabled."""
        remaining = self.n_stages
        while remaining > 0:
            x = self.interpolator(x, scale_factor=self.multiplier)
            remaining -= 1
        return self.channel_mapper(x) if self.remap_output else x

    def encode(self, x):
        """Alias for calling the module."""
        return self(x)
270
+
271
+ ### not using - hugging face implementation
272
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        """
        Args:
            version: Name handed to ``from_pretrained`` for tokenizer and text model.
            device: Device the token ids are moved to.
            max_length: Length every sequence is truncated/padded to.
        """
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        # NOTE(review): this only sets an attribute on the loaded text model;
        # whether anything reads it cannot be told from this file.
        self.transformer.projection_dim = 512
        self.freeze()

    def freeze(self):
        """Put the text model in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the text model's ``last_hidden_state``."""
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=token_ids).last_hidden_state

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)
301
+
302
class FrozenCLIPTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """

    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
        """
        Args:
            version: Model name handed to ``clip.load`` (loaded on the CPU).
            device: Device the tokens are moved to in ``forward``.
            max_length: Stored as ``self.max_length``.
            n_repeat: How often ``encode`` repeats the embedding along axis 1.
            normalize: Whether ``forward`` L2-normalises along dim 1.
        """
        super().__init__()
        loaded_model, _ = clip.load(version, jit=False, device="cpu")
        self.model = loaded_model
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        """Put the model in eval mode and disable gradients for all parameters."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return its (optionally normalised) text encoding."""
        token_ids = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(token_ids)
        if not self.normalize:
            return z
        return z / torch.linalg.norm(z, dim=1, keepdim=True)

    def encode(self, text):
        """Encode ``text`` and repeat the result ``n_repeat`` times along axis 1."""
        z = self(text)
        if z.ndim == 2:
            z = z[:, None, :]
        return repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
332
+
333
class FrozenBioClipTextEmbedder(nn.Module):
    """
    Uses the BioClip transformer encoder for text.
    """

    def __init__(self, version='hf-hub:imageomics/bioclip', device="cuda", max_length=77, n_repeat=1, normalize=True):
        """
        Args:
            version: Name handed to ``open_clip`` for both model and tokenizer.
            device: Device the model and the tokens are moved to.
            max_length: Stored as ``self.max_length``.
            n_repeat: How often ``encode`` repeats the embedding along axis 1.
            normalize: Whether ``forward`` L2-normalises along dim 1.
        """
        super().__init__()
        loaded_model, _, _ = open_clip.create_model_and_transforms(version)
        self.model = loaded_model.eval().to(device)
        self.tokenizer = open_clip.get_tokenizer(version)
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        """Put the model in eval mode and disable gradients for all parameters."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return its (optionally normalised) text encoding."""
        token_ids = self.tokenizer(text).to(self.device)
        z = self.model.encode_text(token_ids)
        if not self.normalize:
            return z
        return z / torch.linalg.norm(z, dim=1, keepdim=True)

    def encode(self, text):
        """Encode ``text`` and repeat the result ``n_repeat`` times along axis 1."""
        z = self(text)
        if z.ndim == 2:
            z = z[:, None, :]
        return repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
370
+
371
+
372
+ # class FrozenClipImageEmbedder(nn.Module):
373
+ # """
374
+ # Uses the CLIP image encoder.
375
+ # """
376
+ # def __init__(
377
+ # self,
378
+ # model,
379
+ # jit=False,
380
+ # device='cuda' if torch.cuda.is_available() else 'cpu',
381
+ # antialias=False,
382
+ # ):
383
+ # super().__init__()
384
+ # self.model, _ = clip.load(name=model, device=device, jit=jit)
385
+
386
+ # self.antialias = antialias
387
+
388
+ # self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
389
+ # self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
390
+
391
+ # def preprocess(self, x):
392
+ # # normalize to [0,1]
393
+ # x = kornia.geometry.resize(x, (224, 224),
394
+ # interpolation='bicubic',align_corners=True,
395
+ # antialias=self.antialias)
396
+ # x = (x + 1.) / 2.
397
+ # # renormalize according to clip
398
+ # x = kornia.enhance.normalize(x, self.mean, self.std)
399
+ # return x
400
+
401
+ # def forward(self, x):
402
+ # # x is assumed to be in range [-1,1]
403
+ # return self.model.encode_image(self.preprocess(x))
404
+
ldm/modules/image_degradation/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2
+ from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
ldm/modules/image_degradation/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (353 Bytes). View file
 
ldm/modules/image_degradation/__pycache__/bsrgan.cpython-38.pyc ADDED
Binary file (19.4 kB). View file
 
ldm/modules/image_degradation/__pycache__/bsrgan_light.cpython-38.pyc ADDED
Binary file (17.2 kB). View file
 
ldm/modules/image_degradation/__pycache__/utils_image.cpython-38.pyc ADDED
Binary file (20.7 kB). View file
 
ldm/modules/image_degradation/bsrgan.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ # --------------------------------------------
4
+ # Super-Resolution
5
+ # --------------------------------------------
6
+ #
7
+ # Kai Zhang ([email protected])
8
+ # https://github.com/cszn
9
+ # From 2019/03--2021/08
10
+ # --------------------------------------------
11
+ """
12
+
13
+ import numpy as np
14
+ import cv2
15
+ import torch
16
+
17
+ from functools import partial
18
+ import random
19
+ from scipy import ndimage
20
+ import scipy
21
+ import scipy.stats as ss
22
+ from scipy.interpolate import interp2d
23
+ from scipy.linalg import orth
24
+ import albumentations
25
+
26
+ import ldm.modules.image_degradation.utils_image as util
27
+
28
+
29
def modcrop_np(img, sf):
    '''
    Crop a copy of the image so its first two axes are multiples of ``sf``.

    Args:
        img: numpy image with two or three axes
        sf: scale factor
    Return:
        cropped copy; the input array is left untouched
    '''
    dim0, dim1 = img.shape[:2]
    keep0 = dim0 - dim0 % sf
    keep1 = dim1 - dim1 % sf
    return np.copy(img)[:keep0, :keep1, ...]
40
+
41
+
42
+ """
43
+ # --------------------------------------------
44
+ # anisotropic Gaussian kernels
45
+ # --------------------------------------------
46
+ """
47
+
48
+
49
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)

    Args:
        k: square 2-D numpy kernel.
    Returns:
        The combined kernel, cropped by ``k.shape[0] // 2`` on every side and
        normalised to sum to 1.
    """
    k_size = k.shape[0]
    # Size of the big kernel
    big_size = 3 * k_size - 2
    big_k = np.zeros((big_size, big_size))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to drop very small values (a smaller
    # kernel is cheaper to apply). The end index is written out explicitly
    # because a "-crop" slice is empty when crop == 0 (1x1 input kernel).
    crop = k_size // 2
    cropped_big_k = big_k[crop:big_size - crop, crop:big_size - crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()
63
+
64
+
65
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """Generate an anisotropic Gaussian kernel.

    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    # Rotated unit vector; it and its perpendicular form the eigenvector matrix.
    v = np.dot(rotation, np.array([1., 0.]))
    eigvecs = np.array([[v[0], v[1]], [v[1], -v[0]]])
    eigvals = np.array([[l1, 0], [0, l2]])
    covariance = np.dot(np.dot(eigvecs, eigvals), np.linalg.inv(eigvecs))
    return gm_blur_kernel(mean=[0, 0], cov=covariance, size=ksize)
84
+
85
+
86
def gm_blur_kernel(mean, cov, size=15):
    """Sample a bivariate normal density on a ``size`` x ``size`` grid.

    Args:
        mean: Mean of the distribution as ``[x, y]``.
        cov: 2x2 covariance matrix.
        size: Side length of the kernel.
    Returns:
        ``(size, size)`` array normalised to sum to 1. Entry ``[y, x]`` is the
        density at the grid point ``(x, y)``; the grid is centred on zero.
    """
    center = size / 2.0 + 0.5
    offsets = np.arange(size) - center + 1
    # grid_x varies along columns, grid_y along rows, so points[y, x] = (cx, cy).
    grid_x, grid_y = np.meshgrid(offsets, offsets)
    points = np.stack([grid_x, grid_y], axis=-1)
    # One vectorised call instead of size**2 separate pdf evaluations.
    # reshape() restores the 2-D shape that scipy squeezes away for size == 1.
    k = np.asarray(ss.multivariate_normal.pdf(points, mean=mean, cov=cov)).reshape(size, size)
    return k / np.sum(k)
97
+
98
+
99
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    The image is resampled bilinearly at positions moved by
    ``(sf - 1) / 2`` pixels; sample positions are clamped to the image.

    Args:
        x: 2-D array, or 3-D array with channels on the last axis
        sf: scale factor
        upper_left: shift direction
    Returns:
        For 2-D input a new float64 array. For 3-D input the input array
        itself, overwritten channel by channel.
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    if not upper_left:
        shift = -shift
    cols = np.clip(np.arange(0, w, 1.0) + shift, 0, w - 1)
    rows = np.clip(np.arange(0, h, 1.0) + shift, 0, h - 1)
    sample_at = np.stack(np.meshgrid(rows, cols, indexing='ij'))

    def resample(plane):
        # order=1 is bilinear interpolation. This replaces
        # scipy.interpolate.interp2d, which was removed in SciPy 1.14.
        return ndimage.map_coordinates(plane, sample_at, output=np.float64,
                                       order=1, mode='nearest')

    if x.ndim == 2:
        x = resample(x)
    elif x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = resample(x[:, :, i])

    return x
126
+
127
+
128
def blur(x, k):
    '''
    Filter every image of a batch with its own kernel.

    x: image, NxcxHxW
    k: kernel, Nx1xhxw

    The input is replicate-padded so the output keeps the spatial size of
    ``x`` for odd kernel sizes. The same kernel is applied to all channels of
    an image.
    '''
    n, c = x.shape[:2]
    pad_h, pad_w = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad lists the last dimension first: (left, right, top, bottom). The
    # width padding therefore comes from the kernel width and the height
    # padding from the kernel height, which matters for non-square kernels.
    x = torch.nn.functional.pad(x, pad=(pad_w, pad_w, pad_h, pad_h), mode='replicate')
    # One kernel per (image, channel) pair, applied as a grouped convolution
    # over a single "image" with N*c channels.
    k = k.repeat(1, c, 1, 1)
    k = k.reshape(-1, 1, k.shape[2], k.shape[3])
    x = x.reshape(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.reshape(n, c, x.shape[2], x.shape[3])

    return x
143
+
144
+
145
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Draw a random anisotropic Gaussian kernel.

    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Args:
        k_size: numpy array with the two kernel side lengths.
        scale_factor: numpy array with the two scale factors; it shifts the
            centre of the Gaussian.
        min_var, max_var: range the two eigenvalues are drawn from.
        noise_level: amplitude of the multiplicative uniform noise.
    Returns:
        Kernel normalised to sum to 1.
    """
    # Random eigenvalues and rotation angle of the covariance matrix. The
    # order of the random draws is kept fixed so seeded runs stay reproducible.
    var_range = max_var - min_var
    lambda_1 = min_var + np.random.rand() * var_range
    lambda_2 = min_var + np.random.rand() * var_range
    theta = np.random.rand() * np.pi
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance matrix from eigenvalues and rotation
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    sigma = rotation @ np.diag([lambda_1, lambda_2]) @ rotation.T
    inv_sigma = np.linalg.inv(sigma)[None, None, :, :]

    # Expectation position (shifting kernel for aligned image)
    mu = (k_size // 2 - 0.5 * (scale_factor - 1))[None, None, :, None]

    # Pixel coordinates relative to the expectation position
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    offsets = np.stack([grid_x, grid_y], 2)[:, :, :, None] - mu
    offsets_t = offsets.transpose(0, 1, 3, 2)

    # Gaussian value for every pixel of the kernel, with multiplicative noise
    raw_kernel = np.exp(-0.5 * np.squeeze(offsets_t @ inv_sigma @ offsets)) * (1 + noise)

    # Normalize the kernel and return
    return raw_kernel / np.sum(raw_kernel)
185
+
186
+
187
def fspecial_gaussian(hsize, sigma):
    """Return a normalised ``hsize`` x ``hsize`` isotropic Gaussian kernel.

    Args:
        hsize: Side length of the kernel.
        sigma: Standard deviation of the Gaussian.
    Returns:
        2-D array. Entries below machine epsilon times the maximum are set to
        zero before the kernel is normalised to sum to 1.
    """
    half = (hsize - 1.0) / 2.0
    axis = np.arange(-half, half + 1)
    x, y = np.meshgrid(axis, axis)
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    # np.finfo instead of scipy.finfo: the NumPy aliases in the scipy
    # top-level namespace were deprecated and have been removed.
    h[h < np.finfo(float).eps * h.max()] = 0
    total = h.sum()
    if total != 0:
        h = h / total
    return h
199
+
200
+
201
def fspecial_laplacian(alpha):
    """Return the 3x3 Laplacian kernel for shape parameter ``alpha``.

    ``alpha`` is clamped to the range [0, 1] before use.
    """
    alpha = max(0, min(alpha, 1))
    scale = alpha + 1
    corner = alpha / scale
    edge = (1 - alpha) / scale
    centre = -4 / scale
    rows = [[corner, edge, corner],
            [edge, centre, edge],
            [corner, edge, corner]]
    return np.array(rows)
208
+
209
+
210
def fspecial(filter_type, *args, **kwargs):
    '''
    Build a filter kernel by name ('gaussian' or 'laplacian'); the remaining
    arguments are forwarded to the matching builder. Any other name yields None.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    if filter_type == 'gaussian':
        kernel = fspecial_gaussian(*args, **kwargs)
    elif filter_type == 'laplacian':
        kernel = fspecial_laplacian(*args, **kwargs)
    else:
        kernel = None
    return kernel
219
+
220
+
221
+ """
222
+ # --------------------------------------------
223
+ # degradation models
224
+ # --------------------------------------------
225
+ """
226
+
227
+
228
def bicubic_degradation(x, sf=3):
    '''
    Downscale an image by the factor ``sf`` with ``util.imresize_np``.

    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    downscale = 1 / sf
    return util.imresize_np(x, scale=downscale)
238
+
239
+
240
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ndimage.convolve is the same function as ndimage.filters.convolve; the
    # `filters` sub-namespace is deprecated in SciPy.
    # The kernel gets a trailing axis so every channel is filtered alike.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
260
+
261
+
262
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # ndimage.convolve is the same function as ndimage.filters.convolve; the
    # `filters` sub-namespace is deprecated in SciPy.
    # The kernel gets a trailing axis so every channel is filtered alike.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
282
+
283
+
284
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    # ndimage.convolve is the same function as ndimage.filters.convolve; the
    # `filters` sub-namespace is deprecated in SciPy.
    # The kernel gets a trailing axis so every channel is filtered alike.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # Keep every sf-th pixel along both spatial axes, starting at pixel 0.
    return x[::sf, ::sf, ...]
297
+
298
+
299
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur; an even value is
            increased by one. Default: 50.
        threshold (int): Compared against abs(I - B) * 255. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    kernel_size = (radius, radius)
    blurred = cv2.GaussianBlur(img, kernel_size, 0)
    residual = img - blurred
    hard_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, kernel_size, 0)

    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
323
+
324
+
325
def add_blur(img, sf=4):
    """Blur an image with a randomly drawn Gaussian kernel.

    With probability 0.5 the kernel is anisotropic, otherwise isotropic. The
    kernel side length is ``2 * randint(2, 11) + 3``.

    Args:
        img: image with channels on the last axis.
        sf: scale factor; larger values allow wider kernels.
    Returns:
        The blurred image.
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    # ndimage.convolve is the same function as ndimage.filters.convolve; the
    # `filters` sub-namespace is deprecated in SciPy.
    # The kernel gets a trailing axis so every channel is filtered alike.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
337
+
338
+
339
def add_resize(img, sf=4):
    """Randomly rescale an image and clip the result to [0, 1].

    The image is enlarged (factor in [1, 2]) with probability 0.2, shrunk
    (factor in [0.5 / sf, 1]) with probability 0.7 and otherwise keeps its
    size. The cv2 interpolation flag is drawn from {1, 2, 3}.
    """
    draw = np.random.rand()
    if draw > 0.8:  # up
        factor = random.uniform(1, 2)
    elif draw < 0.7:  # down
        factor = random.uniform(0.5 / sf, 1)
    else:
        factor = 1.0
    new_width = int(factor * img.shape[1])
    new_height = int(factor * img.shape[0])
    resized = cv2.resize(img, (new_width, new_height), interpolation=random.choice([1, 2, 3]))
    return np.clip(resized, 0.0, 1.0)
351
+
352
+
353
+ # def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
354
+ # noise_level = random.randint(noise_level1, noise_level2)
355
+ # rnum = np.random.rand()
356
+ # if rnum > 0.6: # add color Gaussian noise
357
+ # img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
358
+ # elif rnum < 0.4: # add grayscale Gaussian noise
359
+ # img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
360
+ # else: # add noise
361
+ # L = noise_level2 / 255.
362
+ # D = np.diag(np.random.rand(3))
363
+ # U = orth(np.random.rand(3, 3))
364
+ # conv = np.dot(np.dot(np.transpose(U), D), U)
365
+ # img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
366
+ # img = np.clip(img, 0.0, 1.0)
367
+ # return img
368
+
369
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random Gaussian noise to an image and clip the result to [0, 1].

    One of three noise types is drawn: independent per-channel noise
    (probability 0.4), noise shared by all channels (0.4), or noise with a
    random 3x3 channel covariance (0.2). The input array is not modified.

    Args:
        img: image with three channels on the last axis.
        noise_level1, noise_level2: inclusive range (on a 0-255 scale) the
            noise level is drawn from.
    """
    noise_level = random.randint(noise_level1, noise_level2)
    draw = np.random.rand()
    sigma = noise_level / 255.0
    if draw > 0.6:  # color Gaussian noise
        noise = np.random.normal(0, sigma, img.shape)
    elif draw < 0.4:  # grayscale Gaussian noise
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:  # channel-correlated noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2])
    return np.clip(img + noise.astype(np.float32), 0.0, 1.0)
384
+
385
+
386
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add random multiplicative (speckle) noise and clip the result to [0, 1].

    One of three noise types is drawn: independent per-channel noise
    (probability 0.4), noise shared by all channels (0.4), or noise with a
    random 3x3 channel covariance (0.2). The input is clipped into a new
    array first, so the caller's array is not modified.

    Args:
        img: image with three channels on the last axis.
        noise_level1, noise_level2: inclusive range (on a 0-255 scale) the
            noise level is drawn from.
    """
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    draw = random.random()
    sigma = noise_level / 255.0
    if draw > 0.6:
        noise = np.random.normal(0, sigma, img.shape)
    elif draw < 0.4:
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2])
    # In-place add on the clipped copy keeps the dtype of that copy.
    img += img * noise.astype(np.float32)
    return np.clip(img, 0.0, 1.0)
402
+
403
+
404
def add_Poisson_noise(img):
    """Add random Poisson noise to an image and clip the result to [0, 1].

    The image is first quantised to 256 levels. With probability 0.5 every
    channel gets its own Poisson noise; otherwise the noise is computed on a
    weighted grey version of the first three channels and added to all channels.
    """
    def quantise(values):
        # Round to the nearest of 256 levels in [0, 1].
        return np.clip((values * 255.0).round(), 0, 255) / 255.

    img = quantise(img)
    vals = 10 ** (2 * random.random() + 2.0)  # 10**2 .. 10**4
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = quantise(np.dot(img[..., :3], [0.299, 0.587, 0.114]))
        noisy_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals
        noise_gray = noisy_gray - img_gray
        img += noise_gray[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
416
+
417
+
418
def add_JPEG_noise(img):
    """Round-trip the image through JPEG at a random quality in [30, 95].

    Args:
        img: RGB image accepted by ``util.single2uint``
            (presumably float in [0, 1] -- confirm against utils_image).

    Returns:
        The decoded RGB image after ``util.uint2single``.
    """
    quality = random.randint(30, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    decoded = cv2.imdecode(encoded, 1)  # flag 1: decode as 3-channel colour
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
425
+
426
+
427
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut a random patch from ``lq`` and the matching, sf-times larger patch from ``hq``.

    Args:
        lq: low-quality image, H x W x C.
        hq: high-quality image, indexed at ``sf`` times the ``lq`` coordinates.
        sf: scale factor between ``lq`` and ``hq`` coordinates.
        lq_patchsize: side length of the ``lq`` patch.

    Returns:
        ``(lq_patch, hq_patch)``.

    Raises:
        ValueError: from ``random.randint`` when ``lq`` is smaller than the patch.
    """
    lq_h, lq_w = lq.shape[:2]
    top = random.randint(0, lq_h - lq_patchsize)
    left = random.randint(0, lq_w - lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    hq_top = int(top * sf)
    hq_left = int(left * sf)
    hq_side = lq_patchsize * sf
    hq_patch = hq[hq_top:hq_top + hq_side, hq_left:hq_left + hq_side, :]
    return lq_patch, hq_patch
436
+
437
+
438
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model; the ISP step is skipped when it is None
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize * sf on either side
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    # mod crop: rows are bounded by the height, columns by the width
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2  # the remaining steps only have to halve the image once more

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # downsample2 (step 2) must run before downsample3 (step 3)
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i in (0, 1):
            # steps 0 and 1 are two independent blur passes
            img = add_blur(img, sf=sf)

        elif i == 2:
            # remember the size before downsample2; downsample3 targets (a / sf, b / sf)
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # ndimage.convolve replaces the deprecated ndimage.filters namespace
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
527
+
528
+
529
+ # todo no isp_model?
530
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: HXWXC image accepted by util.uint2single (presumably uint8 in [0, 255] -- confirm)
    sf: scale factor
    isp_model: unused; the camera ISP step is disabled in this variant
    Returns
    -------
    example: dict with the single key "image" holding the degraded image after
             util.single2uint; no patch is cropped and no high-quality image is returned
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    h1, w1 = image.shape[:2]
    # mod crop: rows are bounded by the height, columns by the width
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2  # the remaining steps only have to halve the image once more

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # downsample2 (step 2) must run before downsample3 (step 3)
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i in (0, 1):
            # steps 0 and 1 are two independent blur passes
            image = add_blur(image, sf=sf)

        elif i == 2:
            # remember the size before downsample2; downsample3 targets (a / sf, b / sf)
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # ndimage.convolve replaces the deprecated ndimage.filters namespace
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # step 6 (camera ISP noise) is intentionally a no-op in this variant

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
614
+
615
+
616
+ # TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
617
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    shuffle_prob: probability of fully shuffling the 13 degradation steps; otherwise
                  only the two noise groups (steps 2-5 and 9-12) are shuffled internally
    use_sharp: sharpen the image before it is used as the high-quality reference
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model; the ISP steps are skipped when it is None
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize * sf on either side
    """

    h1, w1 = img.shape[:2]
    # mod crop: rows are bounded by the height, columns by the width
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise, JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], 4)
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], 4)

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    # steps 0-5 and 7-12 are the same blur/resize/noise sequence; step 6 is JPEG
    for i in shuffle_order:
        if i in (0, 7):
            img = add_blur(img, sf=sf)
        elif i in (1, 8):
            img = add_resize(img, sf=sf)
        elif i in (2, 9):
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i in (3, 10):
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i in (4, 11):
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i in (5, 12):
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
                     interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq
702
+
703
+
704
if __name__ == '__main__':
    # Smoke test: degrade a test image 20 times and save side-by-side comparisons
    # (bicubic downscale | degraded | original), each upscaled with nearest neighbour.
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    # degradation_bsrgan_variant converts with uint2single itself, so it gets the
    # raw image; the float copy is only used for the reference outputs
    img_hq = util.uint2single(img)
    print(img_hq)
    for i in range(20):
        print(i)
        # the variant returns a dict; the degraded image is under "image"
        img_lq = util.uint2single(deg_fn(img)["image"])
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
729
+
730
+
ldm/modules/image_degradation/bsrgan_light.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+
6
+ from functools import partial
7
+ import random
8
+ from scipy import ndimage
9
+ import scipy
10
+ import scipy.stats as ss
11
+ from scipy.interpolate import interp2d
12
+ from scipy.linalg import orth
13
+ import albumentations
14
+
15
+ import ldm.modules.image_degradation.utils_image as util
16
+
17
+ """
18
+ # --------------------------------------------
19
+ # Super-Resolution
20
+ # --------------------------------------------
21
+ #
22
+ # Kai Zhang ([email protected])
23
+ # https://github.com/cszn
24
+ # From 2019/03--2021/08
25
+ # --------------------------------------------
26
+ """
27
+
28
+
29
def modcrop_np(img, sf):
    '''
    Crop a copy of the image so both leading dimensions are multiples of sf.
    Args:
        img: numpy image, 2-D or 3-D (the first two axes are cropped)
        sf: scale factor
    Return:
        cropped image (taken from a copy; the input is not modified)
    '''
    dim0, dim1 = img.shape[:2]
    keep0 = dim0 - dim0 % sf
    keep1 = dim1 - dim1 % sf
    return np.copy(img)[:keep0, :keep1, ...]
40
+
41
+
42
+ """
43
+ # --------------------------------------------
44
+ # anisotropic Gaussian kernels
45
+ # --------------------------------------------
46
+ """
47
+
48
+
49
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)

    A weighted copy of ``k`` is accumulated at every stride-2 offset of a
    (3 * size - 2) square canvas; a border of ``size // 2`` is then cropped and
    the result normalised to sum to 1.
    """
    size = k.shape[0]
    canvas_side = 3 * size - 2
    canvas = np.zeros((canvas_side, canvas_side))
    # place k, weighted by each of its own entries, at stride-2 offsets
    for row, col in np.ndindex(size, size):
        top, left = 2 * row, 2 * col
        canvas[top:top + size, left:left + size] += k[row, col] * k
    # drop the border, which only holds very small values
    margin = size // 2
    trimmed = canvas[margin:-margin, margin:-margin]
    return trimmed / trimmed.sum()
63
+
64
+
65
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    # principal direction: the unit x vector rotated by theta
    axis = np.dot(rotation, np.array([1., 0.]))
    eigvecs = np.array([[axis[0], axis[1]], [axis[1], -axis[0]]])
    eigvals = np.array([[l1, 0], [0, l2]])
    # covariance assembled from its eigen-decomposition
    covariance = np.dot(np.dot(eigvecs, eigvals), np.linalg.inv(eigvecs))
    return gm_blur_kernel(mean=[0, 0], cov=covariance, size=ksize)
84
+
85
+
86
def gm_blur_kernel(mean, cov, size=15):
    """Sample a 2-D Gaussian pdf on a size x size grid and normalise it to sum to 1.

    Args:
        mean: 2-element mean passed to ``scipy.stats.multivariate_normal.pdf``.
        cov: 2 x 2 covariance passed to the same call.
        size: side length of the kernel.

    Returns:
        size x size kernel; entry [y, x] is the pdf at (x, y) measured from the grid centre.
    """
    center = size / 2.0 + 0.5
    kernel = np.zeros([size, size])
    for row, col in np.ndindex(size, size):
        dy = row - center + 1
        dx = col - center + 1
        kernel[row, col] = ss.multivariate_normal.pdf([dx, dy], mean=mean, cov=cov)
    return kernel / np.sum(kernel)
97
+
98
+
99
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    The array is resampled with bilinear interpolation at coordinates moved by
    ``(sf - 1) / 2`` pixels (clamped to the array bounds).

    Args:
        x: 2-D array or 3-D array (channels last); each of the first two
            dimensions must have at least 2 samples
        sf: scale factor
        upper_left: shift direction (sample at +shift when True, -shift otherwise)
    Returns:
        the resampled array as a new array; arrays that are neither 2-D nor 3-D
        are returned unchanged
    """
    # interp2d was removed in SciPy 1.14; a degree-1 RectBivariateSpline is the
    # documented equivalent for data on a regular grid.
    from scipy.interpolate import RectBivariateSpline

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    if not upper_left:
        shift = -shift
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    x1 = np.clip(xv + shift, 0, w - 1)
    y1 = np.clip(yv + shift, 0, h - 1)
    # evaluate point-wise so repeated (clamped) coordinates need no special care
    yy, xx = np.meshgrid(y1, x1, indexing='ij')

    def _resample(plane):
        # bilinear interpolation of one 2-D plane at the shifted coordinates
        return RectBivariateSpline(yv, xv, plane, kx=1, ky=1).ev(yy, xx)

    if x.ndim == 2:
        return _resample(x)
    if x.ndim == 3:
        shifted = x.copy()  # do not overwrite the caller's array
        for i in range(x.shape[-1]):
            shifted[:, :, i] = _resample(x[:, :, i])
        return shifted
    return x
126
+
127
+
128
def blur(x, k):
    '''
    Filter every channel of every sample with that sample's own kernel.
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    Returns the filtered tensor, NxcxHxW for odd kernel sizes (replicate padding).
    '''
    n, c = x.shape[:2]
    pad_h, pad_w = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad takes (left, right, top, bottom): width padding first, then height
    x = torch.nn.functional.pad(x, pad=(pad_w, pad_w, pad_h, pad_h), mode='replicate')
    # one kernel copy per channel, flattened into n * c single-channel filters
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    # fold the batch into the channel axis so a grouped conv pairs each plane with its filter
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x
143
+
144
+
145
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf

    Draws a random anisotropic Gaussian kernel: two eigenvalues in
    [min_var, max_var), a rotation angle in [0, pi), optional multiplicative
    uniform noise of amplitude noise_level, normalised to sum to 1.
    """
    # random draws, in this order: two eigenvalues, the angle, the noise field
    var_span = max_var - min_var
    eig_a = min_var + np.random.rand() * var_span
    eig_b = min_var + np.random.rand() * var_span
    angle = np.random.rand() * np.pi
    jitter = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # covariance from eigenvalues and rotation, then its inverse
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    cov = rot @ np.diag([eig_a, eig_b]) @ rot.T
    inv_cov = np.linalg.inv(cov)[None, None, :, :]

    # expected position of the kernel centre (shifted for the aligned image)
    mu = k_size // 2 - 0.5 * (scale_factor - 1)
    mu = mu[None, None, :, None]

    # coordinates of every kernel pixel as column vectors
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    coords = np.stack([grid_x, grid_y], 2)[:, :, :, None]

    # Gaussian value at every pixel, modulated by the noise field
    delta = coords - mu
    delta_t = delta.transpose(0, 1, 3, 2)
    unnormalised = np.exp(-0.5 * np.squeeze(delta_t @ inv_cov @ delta)) * (1 + jitter)

    return unnormalised / np.sum(unnormalised)
185
+
186
+
187
def fspecial_gaussian(hsize, sigma):
    """Build a normalised hsize x hsize Gaussian kernel with standard deviation sigma.

    Entries smaller than machine epsilon times the peak are zeroed before
    normalisation.

    Args:
        hsize: side length of the square kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        hsize x hsize array that sums to 1 (unless every entry is zero).
    """
    half = (hsize - 1.0) / 2.0
    coords = np.arange(-half, half + 1)
    x, y = np.meshgrid(coords, coords)
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    # np.finfo: the scipy.finfo alias is no longer available in current SciPy
    h[h < np.finfo(float).eps * h.max()] = 0
    total = h.sum()
    if total != 0:
        h = h / total
    return h
199
+
200
+
201
def fspecial_laplacian(alpha):
    """Return the 3 x 3 Laplacian kernel for shape parameter alpha (clamped to [0, 1])."""
    alpha = max([0, min([alpha, 1])])
    denom = alpha + 1
    corner = alpha / denom
    edge = (1 - alpha) / denom
    center = -4 / denom
    return np.array([
        [corner, edge, corner],
        [edge, center, edge],
        [corner, edge, corner],
    ])
208
+
209
+
210
def fspecial(filter_type, *args, **kwargs):
    '''
    Build a 'gaussian' or 'laplacian' kernel; remaining arguments are forwarded
    to the matching builder. Any other filter_type yields None.
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    builders = (
        ('gaussian', fspecial_gaussian),
        ('laplacian', fspecial_laplacian),
    )
    for name, build in builders:
        if filter_type == name:
            return build(*args, **kwargs)
    return None
219
+
220
+
221
+ """
222
+ # --------------------------------------------
223
+ # degradation models
224
+ # --------------------------------------------
225
+ """
226
+
227
+
228
def bicubic_degradation(x, sf=3):
    '''
    Downscale by a factor of sf through util.imresize_np.
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1 / sf)
238
+
239
+
240
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ndimage.convolve replaces the deprecated ndimage.filters namespace;
    # the kernel gets a singleton channel axis so each channel is filtered alone
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
260
+
261
+
262
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # ndimage.convolve replaces the deprecated ndimage.filters namespace;
    # the kernel gets a singleton channel axis so each channel is filtered alone
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
282
+
283
+
284
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image (every sf-th pixel of the blurred image, starting at 0)
    '''
    # ndimage.convolve replaces the deprecated ndimage.filters namespace;
    # the kernel gets a singleton channel axis so each channel is filtered alone
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    st = 0  # offset of the first sampled pixel
    return x[st::sf, st::sf, ...]
297
+
298
+
299
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur (bumped to the next odd
            value when even). Default: 50.
        threshold (int): residual threshold on a 0-255 scale. Default: 10.
    """
    ksize = radius + 1 if radius % 2 == 0 else radius  # the blur is called with an odd size
    blurred = cv2.GaussianBlur(img, (ksize, ksize), 0)
    detail = img - blurred
    hard_mask = (np.abs(detail) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, (ksize, ksize), 0)

    sharpened = np.clip(img + weight * detail, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
323
+
324
+
325
def add_blur(img, sf=4):
    """Blur the image with a random anisotropic or isotropic Gaussian kernel.

    With probability 0.5 an anisotropic kernel (size 5-14, random angle and
    eigenvalue scalings) is used, otherwise an isotropic one (size 5-7). The
    width ranges are a quarter of (4 + sf) and (2 + 0.2 * sf).

    Args:
        img: H x W x C image array.
        sf: scale factor; larger values allow wider kernels.

    Returns:
        The blurred image.
    """
    wd2 = (4.0 + sf) / 4
    wd = (2.0 + 0.2 * sf) / 4

    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
    # ndimage.convolve replaces the deprecated ndimage.filters namespace;
    # the kernel gets a singleton channel axis so each channel is filtered alone
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
341
+
342
+
343
def add_resize(img, sf=4):
    """Randomly rescale the image and clip it to [0, 1].

    Draw > 0.8: upscale by a factor in [1, 2]; draw < 0.7: downscale by a
    factor in [0.5 / sf, 1]; otherwise keep the size. The interpolation flag
    is picked at random from 1, 2, 3.

    Args:
        img: H x W x C image array.
        sf: scale factor; bounds the strongest downscale.

    Returns:
        The resized image clipped to [0, 1].
    """
    draw = np.random.rand()
    if 0.7 <= draw <= 0.8:
        ratio = 1.0
    elif draw > 0.8:  # up
        ratio = random.uniform(1, 2)
    else:  # down
        ratio = random.uniform(0.5 / sf, 1)
    new_size = (int(ratio * img.shape[1]), int(ratio * img.shape[0]))  # cv2 wants (width, height)
    resized = cv2.resize(img, new_size, interpolation=random.choice([1, 2, 3]))
    return np.clip(resized, 0.0, 1.0)
355
+
356
+
357
+ # def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
358
+ # noise_level = random.randint(noise_level1, noise_level2)
359
+ # rnum = np.random.rand()
360
+ # if rnum > 0.6: # add color Gaussian noise
361
+ # img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
362
+ # elif rnum < 0.4: # add grayscale Gaussian noise
363
+ # img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
364
+ # else: # add noise
365
+ # L = noise_level2 / 255.
366
+ # D = np.diag(np.random.rand(3))
367
+ # U = orth(np.random.rand(3, 3))
368
+ # conv = np.dot(np.dot(np.transpose(U), D), U)
369
+ # img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
370
+ # img = np.clip(img, 0.0, 1.0)
371
+ # return img
372
+
373
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add one randomly selected kind of Gaussian noise and clip to [0, 1].

    Args:
        img: float image array, H x W x C. The channel-correlated branch draws
            3-component noise, so it needs a last axis that broadcasts with 3.
        noise_level1: inclusive lower bound of the sampled sigma (0-255 scale).
        noise_level2: inclusive upper bound of the sampled sigma (0-255 scale);
            also sets the scale of the channel-correlated branch.

    Returns:
        A new array with noise added, clipped to [0, 1]; ``img`` is left untouched.
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    selector = np.random.rand()
    if 0.4 <= selector <= 0.6:
        # channel-correlated noise with a random covariance
        scale = noise_level2 / 255.
        eigvals = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eigvals), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(scale ** 2 * cov), img.shape[:2])
    elif selector > 0.6:
        # independent noise per channel ("color" noise)
        noise = np.random.normal(0, sigma, img.shape)
    else:
        # one noise plane shared by every channel ("grayscale" noise)
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    noisy = img + noise.astype(np.float32)
    return np.clip(noisy, 0.0, 1.0)
388
+
389
+
390
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) Gaussian noise and clip to [0, 1].

    The input is clipped to [0, 1] first; the noise is then scaled by the
    pixel value, i.e. ``out = img + img * noise``.

    Args:
        img: float image array, H x W x C (the correlated branch draws
            3-component noise).
        noise_level1: inclusive lower bound of the sampled sigma (0-255 scale).
        noise_level2: inclusive upper bound of the sampled sigma (0-255 scale);
            also sets the scale of the channel-correlated branch.

    Returns:
        A new noisy array clipped to [0, 1].
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    clean = np.clip(img, 0.0, 1.0)
    selector = random.random()
    if 0.4 <= selector <= 0.6:
        # channel-correlated noise with a random covariance
        scale = noise_level2 / 255.
        eigvals = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eigvals), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(scale ** 2 * cov), clean.shape[:2])
    elif selector > 0.6:
        # independent noise per channel
        noise = np.random.normal(0, sigma, clean.shape)
    else:
        # one noise plane shared by every channel
        noise = np.random.normal(0, sigma, (*clean.shape[:2], 1))
    noisy = clean + clean * noise.astype(np.float32)
    return np.clip(noisy, 0.0, 1.0)
406
+
407
+
408
def add_Poisson_noise(img):
    """Add Poisson (shot) noise, either per channel or on the luminance only.

    The image is first quantised to 256 levels. A random scale ``10 ** e`` with
    ``e`` drawn from [2, 4) controls the noise strength. With probability 0.5
    every channel gets its own Poisson sample; otherwise a single luminance
    noise plane (weights 0.299/0.587/0.114 over the first three channels) is
    added to all channels.

    Args:
        img: float image array, H x W x C with values expected in [0, 1].

    Returns:
        The noisy image clipped to [0, 1].
    """
    quantised = np.clip((img * 255.0).round(), 0, 255) / 255.
    scale = 10 ** (2 * random.random() + 2.0)  # exponent in [2, 4)
    if random.random() >= 0.5:
        luma = np.dot(quantised[..., :3], [0.299, 0.587, 0.114])
        luma = np.clip((luma * 255.0).round(), 0, 255) / 255.
        luma_noise = np.random.poisson(luma * scale).astype(np.float32) / scale - luma
        # in-place add keeps the dtype of the quantised image
        quantised += luma_noise[:, :, np.newaxis]
    else:
        quantised = np.random.poisson(quantised * scale).astype(np.float32) / scale
    return np.clip(quantised, 0.0, 1.0)
420
+
421
+
422
def add_JPEG_noise(img):
    """Round-trip the image through JPEG at a random quality in [80, 95].

    Args:
        img: RGB image accepted by ``util.single2uint``
            (presumably float in [0, 1] -- confirm against utils_image).

    Returns:
        The decoded RGB image after ``util.uint2single``.
    """
    quality = random.randint(80, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    decoded = cv2.imdecode(encoded, 1)  # flag 1: decode as 3-channel colour
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
429
+
430
+
431
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut a random patch from ``lq`` and the matching, sf-times larger patch from ``hq``.

    Args:
        lq: low-quality image, H x W x C.
        hq: high-quality image, indexed at ``sf`` times the ``lq`` coordinates.
        sf: scale factor between ``lq`` and ``hq`` coordinates.
        lq_patchsize: side length of the ``lq`` patch.

    Returns:
        ``(lq_patch, hq_patch)``.

    Raises:
        ValueError: from ``random.randint`` when ``lq`` is smaller than the patch.
    """
    lq_h, lq_w = lq.shape[:2]
    top = random.randint(0, lq_h - lq_patchsize)
    left = random.randint(0, lq_w - lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    hq_top = int(top * sf)
    hq_left = int(left * sf)
    hq_side = lq_patchsize * sf
    hq_patch = hq[hq_top:hq_top + hq_side, hq_left:hq_left + hq_side, :]
    return lq_patch, hq_patch
440
+
441
+
442
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model; the ISP step is skipped when it is None
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize * sf on either side
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    # mod crop: rows are bounded by the height, columns by the width
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2  # the remaining steps only have to halve the image once more

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # downsample2 (step 2) must run before downsample3 (step 3)
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i in (0, 1):
            # steps 0 and 1 are two independent blur passes
            img = add_blur(img, sf=sf)

        elif i == 2:
            # remember the size before downsample2; downsample3 targets (a / sf, b / sf)
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # ndimage.convolve replaces the deprecated ndimage.filters namespace
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
531
+
532
+
533
+ # todo no isp_model?
534
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is a lighter variant of the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: HXWXC image accepted by util.uint2single (presumably uint8 in [0, 255] -- confirm)
    sf: scale factor
    isp_model: unused; the camera ISP step is disabled in this variant
    Returns
    -------
    example: dict with the single key "image" holding the degraded image after
             util.single2uint; no patch is cropped and no high-quality image is returned
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    h1, w1 = image.shape[:2]
    # mod crop: rows are bounded by the height, columns by the width
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2  # the remaining steps only have to halve the image once more

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # downsample2 (step 2) must run before downsample3 (step 3)
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            # single blur pass; step 1 (second blur) is disabled in this variant
            image = add_blur(image, sf=sf)

        elif i == 2:
            # remember the size before downsample2; downsample3 targets (a / sf, b / sf)
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # ndimage.convolve replaces the deprecated ndimage.filters namespace
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling

            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # steps 1 (second blur) and 6 (camera ISP noise) are intentionally no-ops

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
622
+
623
+
624
+
625
+
626
if __name__ == '__main__':
    # Smoke test: degrade a test image 20 times and save side-by-side comparisons
    # (bicubic downscale | degraded | original), each upscaled with nearest neighbour.
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        # the variant returns a dict; the degraded image is under "image"
        degraded = deg_fn(img)["image"]
        img_hq = util.uint2single(img)
        img_lq = util.uint2single(degraded)
        print(img_lq)
        bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)
        img_lq_bicubic = bicubic(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        # both low-resolution images are blown up to the same size for comparison
        up_size = (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0]))
        lq_nearest = cv2.resize(util.single2uint(img_lq), up_size, interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), up_size, interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
ldm/modules/image_degradation/utils/test.png ADDED
ldm/modules/image_degradation/utils_image.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+ from torchvision.utils import make_grid
8
+ from datetime import datetime
9
+ #import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
10
+
11
+
12
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
13
+
14
+
15
+ '''
16
+ # --------------------------------------------
17
+ # Kai Zhang (github: https://github.com/cszn)
18
+ # 03/Mar/2019
19
+ # --------------------------------------------
20
+ # https://github.com/twhui/SRGAN-pyTorch
21
+ # https://github.com/xinntao/BasicSR
22
+ # --------------------------------------------
23
+ '''
24
+
25
+
26
# File-name suffixes treated as images (matching is case-sensitive).
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']


def is_image_file(filename):
    """Return True if *filename* ends with one of the suffixes in IMG_EXTENSIONS."""
    # str.endswith accepts a tuple of candidate suffixes.
    return filename.endswith(tuple(IMG_EXTENSIONS))
31
+
32
+
33
def get_timestamp():
    """Return the current local time formatted as 'yymmdd-HHMMSS'."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
35
+
36
+
37
def imshow(x, title=None, cbar=False, figsize=None):
    """Show an array as a grayscale image in a new matplotlib figure.

    Args:
        x: array-like image; singleton dimensions are removed with np.squeeze.
        title: optional figure title (skipped when falsy).
        cbar: when truthy, add a colorbar.
        figsize: forwarded to plt.figure.
    """
    # The module-level matplotlib import is commented out, so `plt` would be
    # an undefined name here. Import it lazily so the rest of the module still
    # works without matplotlib installed.
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
45
+
46
+
47
def surf(Z, cmap='rainbow', figsize=None):
    """Plot a 2-D array as a 3-D surface in a new matplotlib figure.

    Args:
        Z: 2-D array of heights.
        cmap: colormap name forwarded to plot_surface.
        figsize: forwarded to plt.figure.
    """
    # The module-level matplotlib import is commented out, so `plt` would be
    # an undefined name here; import it lazily instead.
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    rows, cols = Z.shape[:2]
    # np.meshgrid(a, b) returns arrays of shape (len(b), len(a)), so the column
    # coordinates go first to get grids shaped like Z. The previous argument
    # order only matched Z when it was square.
    X, Y = np.meshgrid(np.arange(0, cols, 1), np.arange(0, rows, 1))
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    plt.show()
58
+
59
+
60
+ '''
61
+ # --------------------------------------------
62
+ # get image paths
63
+ # --------------------------------------------
64
+ '''
65
+
66
+
67
def get_image_paths(dataroot):
    """Return the sorted image paths found under *dataroot*.

    Returns None when *dataroot* is None.
    """
    if dataroot is None:
        return None
    return sorted(_get_paths_from_images(dataroot))
72
+
73
+
74
def _get_paths_from_images(path):
    """Walk *path* recursively and collect every file accepted by is_image_file.

    Directories and file names are each visited in sorted order.

    Raises:
        AssertionError: if *path* is not a directory or no image file is found.
    """
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = [
        os.path.join(dirpath, fname)
        for dirpath, _, fnames in sorted(os.walk(path))
        for fname in sorted(fnames)
        if is_image_file(fname)
    ]
    assert images, '{:s} has no valid image file'.format(path)
    return images
84
+
85
+
86
+ '''
87
+ # --------------------------------------------
88
+ # split large images into small images
89
+ # --------------------------------------------
90
+ '''
91
+
92
+
93
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split a large image into overlapping square patches.

    Args:
        img: 3-D array; the first two axes are tiled, the third is kept whole.
        p_size: side length of each patch.
        p_overlap: overlap between neighbouring patches (stride is
            p_size - p_overlap).
        p_max: the image is split only when both of its first two dimensions
            exceed this value.

    Returns:
        A list of patches, or ``[img]`` when the image is not split.
    """
    w, h = img.shape[:2]
    if not (w > p_max and h > p_max):
        return [img]

    stride = p_size - p_overlap
    # The builtin int replaces np.int, which was deprecated in NumPy 1.20 and
    # removed in 1.24 (np.int was only ever an alias of the builtin).
    w1 = list(np.arange(0, w - p_size, stride, dtype=int))
    h1 = list(np.arange(0, h - p_size, stride, dtype=int))
    # Always add a final patch flush with the far border.
    w1.append(w - p_size)
    h1.append(h - p_size)
    return [img[i:i + p_size, j:j + p_size, :] for i in w1 for j in h1]
110
+
111
+
112
def imssave(imgs, img_path):
    """
    Write each image in *imgs* next to *img_path* as '<stem>_sNNNN.png'.

    imgs: list, N images of size WxHxC
    img_path: path whose directory and base name (without extension) are
        used to build the output file names.
    """
    out_dir = os.path.dirname(img_path)
    stem = os.path.splitext(os.path.basename(img_path))[0]

    for idx, patch in enumerate(imgs):
        # Reverse the first three channels before handing 3-D data to cv2.
        if patch.ndim == 3:
            patch = patch[:, :, [2, 1, 0]]
        out_name = stem + str('_s{:04d}'.format(idx)) + '.png'
        cv2.imwrite(os.path.join(out_dir, out_name), patch)
123
+
124
+
125
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
    """
    Split the large images under original_dataroot into overlapping
    (p_size)x(p_size) patches and save them into taget_dataroot. Only images
    larger than (p_max)x(p_max) are split; smaller ones are saved as a single
    patch.

    Args:
        original_dataroot: directory searched for input images.
        taget_dataroot: directory that receives the patch files.
        n_channels: channel count passed to imread_uint.
        p_size: size of small images
        p_overlap: patch size in training is a good choice
        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
    """
    for src_path in get_image_paths(original_dataroot):
        image = imread_uint(src_path, n_channels=n_channels)
        pieces = patches_from_image(image, p_size, p_overlap, p_max)
        dst_path = os.path.join(taget_dataroot, os.path.basename(src_path))
        imssave(pieces, dst_path)
145
+
146
+ '''
147
+ # --------------------------------------------
148
+ # makedir
149
+ # --------------------------------------------
150
+ '''
151
+
152
+
153
def mkdir(path):
    """Create directory *path* (with parents) unless something already exists there.

    An existing path, whether directory or file, is left untouched.
    """
    if not os.path.exists(path):
        # exist_ok=True closes the race where another process creates the
        # directory between the existence check and makedirs.
        os.makedirs(path, exist_ok=True)
156
+
157
+
158
def mkdirs(paths):
    """Create one directory (when *paths* is a str) or every directory in an iterable."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
164
+
165
+
166
def mkdir_and_rename(path):
    """Create a fresh directory at *path*, archiving anything already there.

    An existing *path* is first renamed to '<path>_archived_<timestamp>'.
    """
    already_present = os.path.exists(path)
    if already_present:
        archived = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(archived))
        os.rename(path, archived)
    os.makedirs(path)
172
+
173
+
174
+ '''
175
+ # --------------------------------------------
176
+ # read image from path
177
+ # opencv is fast, but read BGR numpy image
178
+ # --------------------------------------------
179
+ '''
180
+
181
+
182
+ # --------------------------------------------
183
+ # get uint8 image of size HxWxn_channels (RGB)
184
+ # --------------------------------------------
185
def imread_uint(path, n_channels=3):
    """Read an image file with OpenCV.

    Args:
        path: image file path.
        n_channels: 1 to read as grayscale (HxWx1), 3 to read as RGB
            (HxWx3; a single-channel file is replicated to three channels).

    Returns:
        The image array as returned by OpenCV, reshaped or colour-converted
        as described above.

    Raises:
        ValueError: if n_channels is neither 1 nor 3 (previously this ended
            in an UnboundLocalError).
        OSError: if OpenCV could not read the file (cv2.imread returns None
            instead of raising).
    """
    if n_channels not in (1, 3):
        raise ValueError('n_channels must be 1 or 3, got {}'.format(n_channels))

    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
    else:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
    if img is None:
        raise OSError('Could not read image: {}'.format(path))

    if n_channels == 1:
        return np.expand_dims(img, axis=2)  # HxWx1
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
198
+
199
+
200
+ # --------------------------------------------
201
+ # matlab's imwrite
202
+ # --------------------------------------------
203
def imsave(img, img_path):
    """Write *img* to *img_path* with cv2.imwrite.

    Singleton dimensions are squeezed away first; for 3-D data the first three
    channels are reversed before writing.
    """
    pixels = np.squeeze(img)
    is_multichannel = pixels.ndim == 3
    if is_multichannel:
        pixels = pixels[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, pixels)
208
+
209
def imwrite(img, img_path):
    """Write *img* to *img_path* with cv2.imwrite (same steps as imsave).

    Singleton dimensions are squeezed away first; for 3-D data the first three
    channels are reversed before writing.
    """
    data = np.squeeze(img)
    out = data[:, :, [2, 1, 0]] if data.ndim == 3 else data
    cv2.imwrite(img_path, out)
214
+
215
+
216
+
217
+ # --------------------------------------------
218
+ # get single image of size HxWxn_channels (BGR)
219
+ # --------------------------------------------
220
def read_img(path):
    """Read an image with OpenCV and return it as float32 HWC data.

    The pixel values are divided by 255, a 2-D result gets a trailing channel
    axis, and any channels beyond the third are dropped.

    Returns:
        Numpy float32, HWC, BGR, [0,1] (for 8-bit input).

    Raises:
        OSError: if OpenCV could not read the file (cv2.imread returns None
            instead of raising, which previously surfaced as an
            AttributeError on the next line).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise OSError('Could not read image: {}'.format(path))
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
231
+
232
+
233
+ '''
234
+ # --------------------------------------------
235
+ # image format conversion
236
+ # --------------------------------------------
237
+ # numpy(single) <---> numpy(uint)
238
+ # numpy(single) <---> tensor
239
+ # numpy(uint) <---> tensor
240
+ # --------------------------------------------
241
+ '''
242
+
243
+
244
+ # --------------------------------------------
245
+ # numpy(single) [0, 1] <---> numpy(uint)
246
+ # --------------------------------------------
247
+
248
+
249
def uint2single(img):
    """Divide *img* by 255 and return the result as float32."""
    scaled = img / 255.
    return np.float32(scaled)
252
+
253
+
254
def single2uint(img):
    """Clip *img* to [0, 1], scale by 255, round and return as uint8."""
    clipped = img.clip(0, 1)
    rounded = (clipped * 255.).round()
    return np.uint8(rounded)
257
+
258
+
259
def uint162single(img):
    """Divide *img* by 65535 and return the result as float32."""
    scaled = img / 65535.
    return np.float32(scaled)
262
+
263
+
264
def single2uint16(img):
    """Clip *img* to [0, 1], scale by 65535, round and return as uint16."""
    clipped = img.clip(0, 1)
    rounded = (clipped * 65535.).round()
    return np.uint16(rounded)
267
+
268
+
269
+ # --------------------------------------------
270
+ # numpy(uint) (HxWxC or HxW) <---> tensor
271
+ # --------------------------------------------
272
+
273
+
274
+ # convert uint to 4-dimensional torch tensor
275
def uint2tensor4(img):
    """Convert an HxW or HxWxC array to a 1xCxHxW float tensor divided by 255."""
    hwc = np.expand_dims(img, axis=2) if img.ndim == 2 else img
    chw = torch.from_numpy(np.ascontiguousarray(hwc)).permute(2, 0, 1)
    return chw.float().div(255.).unsqueeze(0)
279
+
280
+
281
+ # convert uint to 3-dimensional torch tensor
282
def uint2tensor3(img):
    """Convert an HxW or HxWxC array to a CxHxW float tensor divided by 255."""
    hwc = np.expand_dims(img, axis=2) if img.ndim == 2 else img
    chw = torch.from_numpy(np.ascontiguousarray(hwc)).permute(2, 0, 1)
    return chw.float().div(255.)
286
+
287
+
288
+ # convert 2/3/4-dimensional torch tensor to uint
289
def tensor2uint(img):
    """Convert a tensor to a uint8 numpy image.

    The tensor is squeezed, cast to float, clamped to [0, 1], scaled by 255
    and rounded. A 3-D result is transposed from CxHxW to HxWxC.

    The input tensor is left unmodified.
    """
    # Out-of-place clamp: squeeze()/float() can return views of the caller's
    # storage, so the former in-place clamp_ could overwrite the input data.
    img = img.data.squeeze().float().clamp(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img*255.0).round())
294
+
295
+
296
+ # --------------------------------------------
297
+ # numpy(single) (HxWxC) <---> tensor
298
+ # --------------------------------------------
299
+
300
+
301
+ # convert single (HxWxC) to 3-dimensional torch tensor
302
def single2tensor3(img):
    """Convert an HxWxC array to a CxHxW float tensor (no rescaling)."""
    hwc = torch.from_numpy(np.ascontiguousarray(img))
    return hwc.permute(2, 0, 1).float()
304
+
305
+
306
+ # convert single (HxWxC) to 4-dimensional torch tensor
307
def single2tensor4(img):
    """Convert an HxWxC array to a 1xCxHxW float tensor (no rescaling)."""
    hwc = torch.from_numpy(np.ascontiguousarray(img))
    return hwc.permute(2, 0, 1).float().unsqueeze(0)
309
+
310
+
311
+ # convert torch tensor to single
312
def tensor2single(img):
    """Convert a tensor to a float numpy array.

    The tensor is squeezed and cast to float; a 3-D result is transposed from
    CxHxW to HxWxC, anything else is returned as is.
    """
    array = img.data.squeeze().float().cpu().numpy()
    return np.transpose(array, (1, 2, 0)) if array.ndim == 3 else array
318
+
319
+ # convert torch tensor to single
320
def tensor2single3(img):
    """Convert a tensor to a float numpy array with a channel axis last.

    The tensor is squeezed and cast to float; a 3-D result is transposed from
    CxHxW to HxWxC and a 2-D result gets a trailing axis of size 1.
    """
    array = img.data.squeeze().float().cpu().numpy()
    rank = array.ndim
    if rank == 3:
        return np.transpose(array, (1, 2, 0))
    if rank == 2:
        return np.expand_dims(array, axis=2)
    return array
327
+
328
+
329
def single2tensor5(img):
    """Permute a 4-D array with (2, 0, 1, 3), cast to float and add a leading axis."""
    source = torch.from_numpy(np.ascontiguousarray(img))
    return source.permute(2, 0, 1, 3).float().unsqueeze(0)
331
+
332
+
333
def single32tensor5(img):
    """Cast an array to a float tensor and prepend two axes of size 1."""
    source = torch.from_numpy(np.ascontiguousarray(img)).float()
    return source.unsqueeze(0).unsqueeze(0)
335
+
336
+
337
def single42tensor4(img):
    """Permute a 4-D array with (2, 0, 1, 3) and cast it to a float tensor."""
    source = torch.from_numpy(np.ascontiguousarray(img))
    return source.permute(2, 0, 1, 3).float()
339
+
340
+
341
+ # from skimage.io import imread, imsave
342
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    Values are clamped to min_max and rescaled to [0, 1] first. A 4-D input
    is tiled into one grid with make_grid. The input tensor is left
    unmodified.

    Raises:
        TypeError: if the squeezed tensor is not 2-D, 3-D or 4-D.
    '''
    # Out-of-place clamp: for a float32 CPU tensor squeeze()/float()/cpu()
    # return views of the caller's storage, so the former in-place clamp_
    # overwrote the input data.
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)
367
+
368
+
369
+ '''
370
+ # --------------------------------------------
371
+ # Augmentation, flip and/or rotate
372
+ # --------------------------------------------
373
+ # The following two are enough.
374
+ # (1) augment_img: numpy image of WxHxC or WxH
375
+ # (2) augment_img_tensor4: tensor image 1xCxWxH
376
+ # --------------------------------------------
377
+ '''
378
+
379
+
380
def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of eight flip/rotate combinations (mode 0-7) to a numpy image.
    Mode 0 returns the input unchanged; any other mode value returns None.
    '''
    transforms = {
        0: lambda a: a,
        1: lambda a: np.flipud(np.rot90(a)),
        2: lambda a: np.flipud(a),
        3: lambda a: np.rot90(a, k=3),
        4: lambda a: np.flipud(np.rot90(a, k=2)),
        5: lambda a: np.rot90(a),
        6: lambda a: np.rot90(a, k=2),
        7: lambda a: np.flipud(np.rot90(a, k=3)),
    }
    transform = transforms.get(mode)
    if transform is None:
        return None
    return transform(img)
399
+
400
+
401
def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of eight flip/rotate combinations (mode 0-7) to a 4-D tensor,
    operating on dimensions 2 and 3. Mode 0 returns the input unchanged; any
    other mode value returns None.
    '''
    spatial = [2, 3]
    transforms = {
        0: lambda t: t,
        1: lambda t: t.rot90(1, spatial).flip([2]),
        2: lambda t: t.flip([2]),
        3: lambda t: t.rot90(3, spatial),
        4: lambda t: t.rot90(2, spatial).flip([2]),
        5: lambda t: t.rot90(1, spatial),
        6: lambda t: t.rot90(2, spatial),
        7: lambda t: t.rot90(3, spatial).flip([2]),
    }
    transform = transforms.get(mode)
    if transform is None:
        return None
    return transform(img)
420
+
421
+
422
def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Run augment_img on a 3-D or 4-D tensor by round-tripping through numpy:
    the spatial axes are moved to the front, augment_img is applied, and the
    original axis order and tensor type are restored.
    '''
    rank = len(img.size())
    array = img.data.cpu().numpy()

    # Move the two spatial axes to the front so augment_img acts on them.
    if rank == 3:
        array = np.transpose(array, (1, 2, 0))
    elif rank == 4:
        array = np.transpose(array, (2, 3, 1, 0))

    array = augment_img(array, mode=mode)
    result = torch.from_numpy(np.ascontiguousarray(array))

    # Undo the axis move from above.
    if rank == 3:
        result = result.permute(2, 0, 1)
    elif rank == 4:
        result = result.permute(3, 2, 0, 1)

    return result.type_as(img)
439
+
440
+
441
def augment_img_np3(img, mode=0):
    """Apply one of eight flip/transpose combinations (mode 0-7) to an HxWxC array.

    The three bits of *mode* select, in order of application: a flip along
    axis 1 (bit 4), a flip along axis 0 (bit 2), and a swap of axes 0 and 1
    (bit 1). Mode 0 returns the input unchanged; any other mode value
    returns None.
    """
    if mode not in range(8):
        return None
    bits = int(mode)
    if bits & 4:
        img = img[:, ::-1, :]
    if bits & 2:
        img = img[::-1, :, :]
    if bits & 1:
        img = img.transpose(1, 0, 2)
    return img
467
+
468
+
469
def augment_imgs(img_list, hflip=True, rot=True):
    """Apply one randomly chosen flip/transpose combination to every image.

    Each enabled operation is switched on with probability 0.5, and the same
    choice is used for all images in *img_list*:
      - hflip enables a flip along axis 1;
      - rot enables both a flip along axis 0 and a swap of axes 0 and 1.

    Returns a new list of HxWxC arrays.
    """
    # The draws stay in this order (and are skipped when the option is off)
    # so that results for a given random seed do not change.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_transpose = rot and random.random() < 0.5

    augmented = []
    for image in img_list:
        if do_hflip:
            image = image[:, ::-1, :]
        if do_vflip:
            image = image[::-1, :, :]
        if do_transpose:
            image = image.transpose(1, 0, 2)
        augmented.append(image)
    return augmented
485
+
486
+
487
+ '''
488
+ # --------------------------------------------
489
+ # modcrop and shave
490
+ # --------------------------------------------
491
+ '''
492
+
493
+
494
def modcrop(img_in, scale):
    """Crop a copy of an HW or HWC image so height and width are multiples of *scale*.

    Rows and columns are removed from the bottom and right edges.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.
    """
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    height, width = img.shape[:2]
    return img[:height - height % scale, :width - width % scale]
508
+
509
+
510
def shave(img_in, border=0):
    """Return a copy of an HW or HWC image with *border* pixels removed from every side."""
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    height, width = img.shape[:2]
    rows = slice(border, height - border)
    cols = slice(border, width - border)
    return img[rows, cols]
516
+
517
+
518
+ '''
519
+ # --------------------------------------------
520
+ # image processing process on numpy image
521
+ # channel_convert(in_c, tar_type, img_list):
522
+ # rgb2ycbcr(img, only_y=True):
523
+ # bgr2ycbcr(img, only_y=True):
524
+ # ycbcr2rgb(img):
525
+ # --------------------------------------------
526
+ '''
527
+
528
+
529
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array with the dtype of the input; Y only when only_y is set,
        otherwise Y, Cb and Cr along the last axis.

    The input array is left unmodified.
    '''
    in_img_type = img.dtype
    if in_img_type == np.uint8:
        img = img.astype(np.float32)
    else:
        # Out-of-place scaling: the former `img *= 255.` rescaled the
        # caller's array in place (the preceding astype result was discarded).
        img = img * 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
551
+
552
+
553
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array with the dtype and shape of the input.

    The input array is left unmodified.
    '''
    in_img_type = img.dtype
    if in_img_type == np.uint8:
        img = img.astype(np.float32)
    else:
        # Out-of-place scaling: the former `img *= 255.` rescaled the
        # caller's array in place (the preceding astype result was discarded).
        img = img * 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
571
+
572
+
573
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array with the dtype of the input; Y only when only_y is set,
        otherwise Y, Cb and Cr along the last axis.

    The input array is left unmodified.
    '''
    in_img_type = img.dtype
    if in_img_type == np.uint8:
        img = img.astype(np.float32)
    else:
        # Out-of-place scaling: the former `img *= 255.` rescaled the
        # caller's array in place (the preceding astype result was discarded).
        img = img * 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
595
+
596
+
597
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, gray and y.

    Supported combinations:
      - in_c == 3, tar_type == 'gray': BGR to gray (HxWx1)
      - in_c == 3, tar_type == 'y':    BGR to y (HxWx1)
      - in_c == 1, tar_type == 'RGB':  gray/y to BGR
    Any other combination returns *img_list* unchanged.
    """
    if in_c == 3:
        if tar_type == 'gray':  # BGR to gray
            return [np.expand_dims(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), axis=2)
                    for image in img_list]
        if tar_type == 'y':  # BGR to y
            return [np.expand_dims(bgr2ycbcr(image, only_y=True), axis=2)
                    for image in img_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) for image in img_list]
    return img_list
609
+
610
+
611
+ '''
612
+ # --------------------------------------------
613
+ # metric, PSNR and SSIM
614
+ # --------------------------------------------
615
+ '''
616
+
617
+
618
+ # --------------------------------------------
619
+ # PSNR
620
+ # --------------------------------------------
621
def calculate_psnr(img1, img2, border=0):
    """Return the PSNR in dB between two images with range [0, 255].

    *border* pixels are cropped from every side before comparing. Identical
    images give float('inf').

    Raises:
        ValueError: if the two images differ in shape.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    height, width = img1.shape[:2]
    rows = slice(border, height - border)
    cols = slice(border, width - border)

    ref = img1[rows, cols].astype(np.float64)
    other = img2[rows, cols].astype(np.float64)
    mse = np.mean((ref - other)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
637
+
638
+
639
+ # --------------------------------------------
640
+ # SSIM
641
+ # --------------------------------------------
642
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    border: pixels cropped from every side before comparing

    2-D images and HxWx1 images are scored directly; for HxWx3 images the
    per-channel scores are averaged.

    Raises:
        ValueError: if the shapes differ, or the images are not 2-D / HxWx1 /
            HxWx3. (A 3-D image with another channel count used to fall
            through and return None.)
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        n_channels = img1.shape[2]
        if n_channels == 3:
            ssims = [ssim(img1[:, :, i], img2[:, :, i]) for i in range(3)]
            return np.array(ssims).mean()
        if n_channels == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    raise ValueError('Wrong input image dimensions.')
667
+
668
+
669
def ssim(img1, img2):
    """Return the mean SSIM of two single-channel images with range [0, 255].

    Local statistics use an 11x11 Gaussian window (sigma 1.5); a 5-pixel
    border is cropped from every filtered map.
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    def _local_mean(data):
        # Gaussian-weighted local average, keeping only the "valid" region.
        return cv2.filter2D(data, -1, window)[5:-5, 5:-5]

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _local_mean(img1**2) - mu1_sq
    sigma2_sq = _local_mean(img2**2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return ssim_map.mean()
690
+
691
+
692
+ '''
693
+ # --------------------------------------------
694
+ # matlab's bicubic imresize (numpy and torch) [0, 1]
695
+ # --------------------------------------------
696
+ '''
697
+
698
+
699
+ # matlab 'imresize' function, now only support 'bicubic'
700
def cubic(x):
    """Evaluate the piecewise cubic interpolation kernel element-wise on tensor *x*.

    One polynomial applies for |x| <= 1, another for 1 < |x| <= 2, and the
    result is zero elsewhere.
    """
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    inner_mask = (absx <= 1).type_as(absx)
    outer_mask = ((absx > 1) * (absx <= 2)).type_as(absx)
    inner_poly = 1.5*absx3 - 2.5*absx2 + 1
    outer_poly = -0.5*absx3 + 2.5*absx2 - 4*absx + 2
    return inner_poly * inner_mask + outer_poly * outer_mask
706
+
707
+
708
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Compute interpolation weights and source indices for a 1-D bicubic resize.

    Args:
        in_length: number of input samples along the axis.
        out_length: number of output samples along the axis.
        scale: resize factor.
        kernel: unused here; the cubic kernel is always applied.
        kernel_width: support of the kernel in input samples.
        antialiasing: when truthy and scale < 1, the kernel is widened by
            1/scale so that it also low-pass filters.

    Returns:
        (weights, indices, sym_len_s, sym_len_e): per-output-sample weights
        and input indices (one row each), plus the padding lengths needed at
        the start and end of the input. The returned indices are shifted so
        that they address the padded input.
    """
    widen = (scale < 1) and (antialiasing)
    if widen:
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    offsets = torch.linspace(0, P - 1, P).view(1, P).expand(out_length, P)
    indices = left.view(out_length, 1).expand(out_length, P) + offsets

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if widen:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    row_sums = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / row_sums.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    zero_counts = torch.sum((weights == 0), 0)
    if not math.isclose(zero_counts[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(zero_counts[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
761
+
762
+
763
+ # --------------------------------------------
764
+ # imresize for tensor image [0, 1]
765
+ # --------------------------------------------
766
def imresize(img, scale, antialiasing=True):
    """Bicubic resize of a tensor image (matlab-style imresize).

    Args:
        img: pytorch tensor, CHW or HW [0,1].
        scale: resize factor, the same for H and W.
        antialiasing: forwarded to calculate_weights_indices.

    Returns:
        CHW or HW float tensor [0,1] w/o round; the output height and width
        are ceil(in * scale).

    The input tensor is left unmodified.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        # Out-of-place: the former in-place unsqueeze_ turned the caller's
        # 2-D tensor into a 3-D one as a side effect.
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    # Mirror the first rows into the top padding.
    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    # Mirror the last rows into the bottom padding.
    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    # Mirror the first columns into the left padding.
    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    # Mirror the last columns into the right padding.
    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        # Drop only the channel axis added above, so an output height or
        # width of 1 is kept (a bare squeeze_() would remove it too).
        out_2 = out_2.squeeze(0)
    return out_2
834
+
835
+
836
+ # --------------------------------------------
837
+ # imresize for numpy image [0, 1]
838
+ # --------------------------------------------
839
def imresize_np(img, scale, antialiasing=True):
    """Bicubic resize of a numpy image (matlab-style imresize).

    Args:
        img: Numpy, HWC or HW [0,1].
        scale: resize factor, the same for H and W.
        antialiasing: forwarded to calculate_weights_indices.

    Returns:
        HWC or HW numpy array [0,1] w/o round; the output height and width
        are ceil(in * scale).
    """
    def _reversed(patch, dim):
        # Return *patch* with its entries along *dim* in reverse order.
        order = torch.arange(patch.size(dim) - 1, -1, -1).long()
        return patch.index_select(dim, order)

    tensor = torch.from_numpy(img)
    need_squeeze = tensor.dim() == 2
    if need_squeeze:
        tensor = tensor.unsqueeze(2)

    in_H, in_W, in_C = tensor.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # process H dimension
    # symmetric copying: image in the middle, mirrored rows above and below
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(tensor)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(_reversed(tensor[:sym_len_Hs, :, :], 0))
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(_reversed(tensor[-sym_len_He:, :, :], 0))

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    taps = weights_H.size(1)
    for row in range(out_H):
        start = int(indices_H[row][0])
        for ch in range(out_C):
            out_1[row, :, ch] = img_aug[start:start + taps, :, ch].transpose(0, 1).mv(weights_H[row])

    # process W dimension
    # symmetric copying: rows result in the middle, mirrored columns left and right
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(_reversed(out_1[:, :sym_len_Ws, :], 1))
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(_reversed(out_1[:, -sym_len_We:, :], 1))

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    taps = weights_W.size(1)
    for col in range(out_W):
        start = int(indices_W[col][0])
        for ch in range(out_C):
            out_2[:, col, ch] = out_1_aug[:, start:start + taps, ch].mv(weights_W[col])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
910
+
911
+
912
if __name__ == '__main__':
    # Placeholder entry point: only prints a separator. The commented lines
    # sketch a manual bicubic-resize check.
    print('---')
    # img = imread_uint('test.bmp', 3)
    # img = uint2single(img)
    # img_bicubic = imresize_np(img, 1/4)
ldm/modules/util.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class ActNorm(nn.Module):
    """Per-channel affine normalisation ``h = scale * (x + loc)``.

    ``loc`` and ``scale`` are learnable parameters of shape
    ``(1, num_features, 1, 1)``. On the first call made in training mode they
    are initialised from the data so that each channel of the output has
    (approximately) zero mean and unit standard deviation over the batch and
    spatial positions; the ``initialized`` buffer records that this happened.

    Args:
        num_features: number of channels (size of dimension 1 of the input).
        logdet: when true, ``forward`` also returns the log-determinant term.
        affine: must be truthy; the non-affine variant is not implemented.
        allow_reverse_init: permit the data-dependent initialisation to be
            triggered by ``reverse`` instead of ``forward``.
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # 0 until the data-dependent initialisation has run, then 1.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        """Set ``loc``/``scale`` from the statistics of a 4-D ``input``.

        ``loc`` becomes the negated per-channel mean and ``scale`` the
        reciprocal of the per-channel standard deviation (plus 1e-6).
        """
        with torch.no_grad():
            # (N, C, H, W) -> (C, N*H*W): one row of samples per channel.
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = flatten.mean(1).view(1, -1, 1, 1)
            std = flatten.std(1).view(1, -1, 1, 1)

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Normalise ``input``; with ``reverse=True`` delegate to ``reverse``.

        Accepts 4-D ``(N, C, H, W)`` or 2-D ``(N, C)`` input (the latter is
        temporarily expanded to ``(N, C, 1, 1)``).

        Returns:
            ``h``, or ``(h, logdet)`` when ``self.logdet`` is set, where
            ``logdet`` has one entry per batch element.
        """
        if reverse:
            return self.reverse(input)

        squeeze = len(input.shape) == 2
        if squeeze:
            input = input[:, :, None, None]

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            # Same value for every sample; match dtype/device of the input.
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        """Invert ``forward``: ``output / scale - loc``.

        Raises:
            RuntimeError: if the layer is still uninitialised in training mode
                and ``allow_reverse_init`` is false.
        """
        # Promote 2-D input *before* any initialisation: ``initialize`` expects
        # a 4-D tensor (it permutes four axes), exactly as in ``forward``.
        squeeze = len(output.shape) == 2
        if squeeze:
            output = output[:, :, None, None]

        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
ldm/modules/x_transformer.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
2
+ import torch
3
+ from torch import nn, einsum
4
+ import torch.nn.functional as F
5
+ from functools import partial
6
+ from inspect import isfunction
7
+ from collections import namedtuple
8
+ from einops import rearrange, repeat, reduce
9
+
10
+ # constants
11
+
12
# Per-head feature size used when the caller does not pass ``dim_head``.
DEFAULT_DIM_HEAD = 64

# Attention maps recorded by a single attention block.
Intermediates = namedtuple('Intermediates', [
    'pre_softmax_attn',
    'post_softmax_attn'
])

# Values collected across a whole layer stack. The typename matches the
# variable name so that repr() and pickling refer to the right class.
LayerIntermediates = namedtuple('LayerIntermediates', [
    'hiddens',
    'attn_intermediates'
])
23
+
24
+
25
class AbsolutePositionalEmbedding(nn.Module):
    """Learned positional embedding: one ``dim``-sized vector per position."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)
        self.init_()

    def init_(self):
        """Fill the embedding table from a normal distribution (std 0.02)."""
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, x):
        """Return the embeddings of positions ``0 .. x.shape[1] - 1``.

        Only the length of dimension 1 and the device of ``x`` are used; the
        result has a leading axis of size 1 so it broadcasts over the batch.
        """
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device)
        pos_emb = self.emb(positions)
        return pos_emb.unsqueeze(0)
37
+
38
+
39
class FixedPositionalEmbedding(nn.Module):
    """Non-learned sinusoidal positional embedding."""

    def __init__(self, dim):
        super().__init__()
        # One inverse frequency per pair of output features.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x, seq_dim=1, offset=0):
        """Return sin/cos features for each position along ``seq_dim`` of ``x``.

        ``offset`` is added to every position index. The sine block and the
        cosine block are concatenated on the last axis, and a leading axis of
        size 1 is added.
        """
        length = x.shape[seq_dim]
        positions = torch.arange(length, device=x.device).type_as(self.inv_freq)
        positions = positions + offset
        angles = torch.einsum('i , j -> i j', positions, self.inv_freq)
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return features[None, :, :]
50
+
51
+
52
+ # helpers
53
+
54
def exists(val):
    """Return True unless ``val`` is ``None`` (falsy values such as 0 still count as existing)."""
    return val is not None
56
+
57
+
58
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise the fallback ``d``.

    When ``d`` is a plain Python function (as judged by
    ``inspect.isfunction``, which includes lambdas) it is called with no
    arguments and its result is used, so expensive fallbacks are only built
    when needed. Any other ``d`` is returned as is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
62
+
63
+
64
def always(val):
    """Build a callable that ignores all of its arguments and returns ``val``."""
    return lambda *args, **kwargs: val
68
+
69
+
70
def not_equals(val):
    """Build a one-argument predicate that is true when its argument ``!= val``."""
    return lambda x: x != val
74
+
75
+
76
def equals(val):
    """Build a one-argument predicate that is true when its argument ``== val``."""
    return lambda x: x == val
80
+
81
+
82
def max_neg_value(tensor):
    """Return the negated largest finite value of ``tensor``'s floating dtype."""
    limits = torch.finfo(tensor.dtype)
    return -limits.max
84
+
85
+
86
+ # keyword argument helpers
87
+
88
def pick_and_pop(keys, d):
    """Remove ``keys`` from ``d`` and return them as a new dict.

    ``d`` is mutated in place. ``keys`` is traversed exactly once, so one-shot
    iterables (generators, iterators) work as well as lists and tuples.

    Raises:
        KeyError: if a key is missing from ``d`` (keys popped before the
            failure stay removed).
    """
    return {key: d.pop(key) for key in keys}
91
+
92
+
93
def group_dict_by_key(cond, d):
    """Split ``d`` into two dicts according to a predicate on its keys.

    Returns:
        ``(matching, rest)``: entries whose key makes ``cond(key)`` truthy,
        followed by all remaining entries. ``d`` itself is not modified.
    """
    matching, rest = {}, {}
    for key, value in d.items():
        bucket = matching if bool(cond(key)) else rest
        bucket[key] = value
    return matching, rest
100
+
101
+
102
def string_begins_with(prefix, str):
    """Return True when the string ``str`` starts with ``prefix``.

    The prefix comes first so the function can be partially applied with a
    fixed prefix. Note that the second parameter shadows the ``str`` builtin
    inside this function.
    """
    return str.startswith(prefix)
104
+
105
+
106
def group_by_key_prefix(prefix, d):
    """Split ``d`` into ``(entries whose key starts with prefix, the rest)``."""
    has_prefix = partial(string_begins_with, prefix)
    return group_dict_by_key(has_prefix, d)
108
+
109
+
110
def groupby_prefix_and_trim(prefix, d):
    """Extract the entries of ``d`` whose key starts with ``prefix``.

    Returns:
        ``(trimmed, rest)``: the matching entries with ``prefix`` stripped
        from their keys, and the entries that did not match (keys untouched).
    """
    has_prefix = partial(string_begins_with, prefix)
    prefixed, remaining = group_dict_by_key(has_prefix, d)
    cut = len(prefix)
    trimmed = {key[cut:]: value for key, value in prefixed.items()}
    return trimmed, remaining
114
+
115
+
116
+ # classes
117
class Scale(nn.Module):
    """Wrap ``fn`` and multiply its (primary) output by a constant ``value``.

    If ``fn`` returns a tuple or list, only the first element is scaled and
    the remaining elements are passed through in a tuple. If ``fn`` returns a
    single value (e.g. a feed-forward block returning one tensor), that value
    is scaled and returned directly instead of being unpacked along its first
    dimension.
    """

    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        if not isinstance(out, (tuple, list)):
            return out * self.value
        first, *rest = out
        return (first * self.value, *rest)
126
+
127
+
128
class Rezero(nn.Module):
    """Wrap ``fn`` and multiply its (primary) output by a learned gate ``g``.

    ``g`` is a single-element parameter initialised to zero, so the wrapped
    branch contributes nothing at the start of training. If ``fn`` returns a
    tuple or list, only the first element is gated and the remaining elements
    are passed through in a tuple; a single return value is gated and returned
    directly instead of being unpacked along its first dimension.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        if not isinstance(out, (tuple, list)):
            return out * self.g
        first, *rest = out
        return (first * self.g, *rest)
137
+
138
+
139
class ScaleNorm(nn.Module):
    """Normalise the last axis by its scaled L2 norm, with one learned gain.

    The norm is multiplied by ``dim ** -0.5`` and clamped from below at
    ``eps`` before dividing; ``g`` is a single scalar parameter.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        scaled_norm = x.norm(dim=-1, keepdim=True) * self.scale
        denom = torch.clamp(scaled_norm, min=self.eps)
        return x / denom * self.g
149
+
150
+
151
class RMSNorm(nn.Module):
    """Normalise the last axis by its scaled L2 norm, with a per-feature gain.

    The norm is multiplied by ``dim ** -0.5`` and clamped from below at
    ``eps`` before dividing; ``g`` holds one learned gain per feature.
    """

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        scaled_norm = x.norm(dim=-1, keepdim=True) * self.scale
        denom = torch.clamp(scaled_norm, min=self.eps)
        return x / denom * self.g
161
+
162
+
163
class Residual(nn.Module):
    """Plain residual connection: adds the branch output to the skipped input."""

    def forward(self, x, residual):
        # ``x`` is the branch output, ``residual`` the value carried around it.
        return x + residual
166
+
167
+
168
class GRUGating(nn.Module):
    """Residual connection that merges branch output and skip value with a GRU cell.

    The branch output is fed as the GRU input and the residual as its hidden
    state; both are flattened from ``(b, n, d)`` to ``(b * n, d)`` first.
    """

    def __init__(self, dim):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, residual):
        flat_x = rearrange(x, 'b n d -> (b n) d')
        flat_residual = rearrange(residual, 'b n d -> (b n) d')
        gated = self.gru(flat_x, flat_residual)
        # Restore the original (b, n, d) layout.
        return gated.reshape_as(x)
180
+
181
+
182
+ # feedforward
183
+
184
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    A single linear layer produces ``2 * dim_out`` features; the first half is
    the value and the second half, passed through GELU, is the gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)
192
+
193
+
194
class FeedForward(nn.Module):
    """Position-wise feed-forward block: project in, dropout, project out.

    Args:
        dim: input feature size.
        dim_out: output feature size; defaults to ``dim``.
        mult: the hidden size is ``int(dim * mult)``.
        glu: use a ``GEGLU`` input projection instead of Linear + GELU.
        dropout: dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim),
                nn.GELU()
            )

        stages = [
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
212
+
213
+
214
+ # attention.
215
class Attention(nn.Module):
    """Multi-head dot-product attention with several optional extensions.

    Supports self- and cross-attention (via ``context``), causal masking,
    "talking heads" mixing of the attention maps across heads, top-k sparse
    attention, learned memory key/value slots, and an optional GLU output
    projection ("attention on attention").

    ``forward`` returns ``(output, Intermediates)``, where the intermediates
    hold the attention logits and the post-softmax attention map.
    """

    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            mask=None,
            talking_heads=False,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False
    ):
        super().__init__()
        # entmax is not available in this copy; only softmax is supported.
        if use_entmax15:
            raise NotImplementedError("Check out entmax activation instead of softmax activation!")
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.causal = causal
        # Stored only; ``forward`` uses its own ``mask`` argument instead.
        self.mask = mask

        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # talking heads: learned (heads x heads) mixing applied before and after softmax
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax (disabled, see the check above)
        # self.attn_fn = entmax15 if use_entmax15 else F.softmax
        self.attn_fn = F.softmax

        # add memory key / values: learned slots prepended to the keys/values of every batch item
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention: the output projection is followed by a GLU
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            rel_pos=None,
            sinusoidal_emb=None,
            prev_attn=None,
            mem=None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: query input; unpacked below as ``(b, n, _)``.
            context: optional key/value source for cross-attention.
            mask: optional boolean ``(b, n)`` mask over query positions
                (True = keep, since its negation is filled with the mask value).
            context_mask: optional boolean mask over context positions.
            rel_pos: optional callable applied to the attention logits.
            sinusoidal_emb: optional callable producing positional features
                that are added to the query and key inputs.
            prev_attn: optional tensor added to the logits (residual attention).
            mem: optional tensor concatenated in front of the key/value inputs
                along the sequence axis.

        Returns:
            ``(out, intermediates)``: the projected attention output and an
            ``Intermediates`` tuple.
        """
        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        # split the feature axis into heads: (b, n, h*d) -> (b, h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # Combine query and key masks into one (b, 1, i, j) mask; a missing
        # side defaults to all-True. Without a context the query mask is
        # reused for the keys.
        input_mask = None
        if any(map(exists, (mask, context_mask))):
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                # memory slots are always attendable
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        # NOTE: this is an alias of ``dots``, not a copy. The in-place
        # ``masked_fill_`` calls below therefore also show up in the recorded
        # tensor, unless ``dots`` is rebound first (the talking-heads einsum
        # does that; whether ``rel_pos`` does depends on the callable).
        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            # True where the key index is greater than the query index;
            # left-padded with False so extra leading key positions
            # (j - i of them) stay visible.
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            # keep only logits that are >= the k-th largest logit of each row
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge heads back: (b, h, n, d) -> (b, n, h*d)
        out = rearrange(out, 'b h n d -> b n (h d)')

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates
368
+
369
+
370
class AttentionLayers(nn.Module):
    """Configurable stack of attention and feed-forward layers.

    The stack is described by a sequence of one-letter layer types:
    ``'a'`` (self-attention), ``'c'`` (cross-attention) and ``'f'``
    (feed-forward). Each entry of ``self.layers`` is a
    ``ModuleList([norm, block, residual_fn])``.

    Keyword arguments prefixed with ``ff_`` are forwarded (prefix removed) to
    ``FeedForward``; those prefixed with ``attn_`` are forwarded to
    ``Attention``. Any other extra keyword argument is dropped.
    """

    def __init__(
            self,
            dim,
            depth,
            heads=8,
            causal=False,
            cross_attend=False,
            only_cross=False,
            use_scalenorm=False,
            use_rmsnorm=False,
            use_rezero=False,
            rel_pos_num_buckets=32,
            rel_pos_max_distance=128,
            position_infused_attn=False,
            custom_layers=None,
            sandwich_coef=None,
            par_ratio=None,
            residual_attn=False,
            cross_residual_attn=False,
            macaron=False,
            pre_norm=True,
            gate_residual=False,
            **kwargs
    ):
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        # NOTE(review): ``dim_head`` is not read again in this constructor.
        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])

        # Position-infused attention: sinusoidal features are handed to every
        # self-attention layer instead of being added once to the embeddings.
        self.has_pos_emb = position_infused_attn
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
        self.rotary_pos_emb = always(None)

        # The relative-position arguments are only validated here; no
        # relative position module is created (``rel_pos`` stays None).
        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
        self.rel_pos = None

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn

        # Norm selection: RMSNorm wins over ScaleNorm, and rezero replaces
        # the norm with an identity.
        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        # The repeating unit of the stack.
        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        # macaron: an additional feed-forward in front of every block
        if macaron:
            default_block = ('f',) + default_block

        # Layer layout, in order of precedence: explicit custom layout,
        # PAR-style layout, sandwich layout, plain repetition.
        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            # attention layers first, regular blocks in the middle, feed-forwards last
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        # number of self-attention layers (= number of memory slots consumed in forward)
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                # cross-attention is built without the causal flag
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            # rezero gating is applied to attention layers only
            if isinstance(layer, Attention) and exists(branch_fn):
                layer = branch_fn(layer)

            if gate_residual:
                residual_fn = GRUGating(dim)
            else:
                residual_fn = Residual()

            self.layers.append(nn.ModuleList([
                norm_fn(),
                layer,
                residual_fn
            ]))

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            mems=None,
            return_hiddens=False
    ):
        """Run ``x`` through every layer of the stack.

        Args:
            x: input hidden states.
            context: key/value source for cross-attention layers.
            mask: mask passed to self- and cross-attention layers.
            context_mask: mask over ``context`` for cross-attention layers.
            mems: optional list with one memory entry per self-attention
                layer; it is copied, so the caller's list is not consumed.
            return_hiddens: also return the recorded intermediates.

        Returns:
            ``x``, or ``(x, LayerIntermediates)`` when ``return_hiddens`` is
            set. ``hiddens`` holds the input of every self-attention layer.
        """
        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == 'a':
                hiddens.append(x)
                # memories are consumed in layer order
                layer_mem = mems.pop(0)

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
                                   prev_attn=prev_attn, mem=layer_mem)
            elif layer_type == 'c':
                out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
            elif layer_type == 'f':
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            # residual attention: carry the logits over to the next layer of the same kind
            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            # post-norm: normalise after the residual, except after the last layer
            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates
            )

            return x, intermediates

        return x
539
+
540
+
541
class Encoder(AttentionLayers):
    """``AttentionLayers`` stack with causal masking forced off."""

    def __init__(self, **kwargs):
        # ``causal`` is fixed by this class; passing it explicitly is rejected.
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)
545
+
546
+
547
+
548
class TransformerWrapper(nn.Module):
    """Token-level wrapper around an ``AttentionLayers`` stack.

    Embeds integer token ids, adds positional embeddings, optionally prepends
    learned memory tokens, runs the attention stack, applies a final
    LayerNorm and projects to logits over the vocabulary.

    Args:
        num_tokens: vocabulary size.
        max_seq_len: size of the learned positional embedding table.
        attn_layers: an ``AttentionLayers`` instance (e.g. ``Encoder``).
        emb_dim: embedding size; defaults to ``attn_layers.dim``. When it
            differs from the model dimension a linear projection is inserted.
        max_mem_len: number of trailing positions kept per layer when
            ``return_mems`` is requested (converted with ``int()`` when used).
        emb_dropout: dropout probability applied to the embeddings.
        num_memory_tokens: number of learned memory tokens prepended to the
            sequence (``None`` means 0).
        tie_embedding: reuse the token embedding matrix as output projection.
        use_pos_emb: add learned absolute positional embeddings (skipped when
            the attention stack reports ``has_pos_emb``).
    """

    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        # ``always(0)`` stands in for "no positional embedding".
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        """Initialise the token embedding from a normal distribution (std 0.02)."""
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_mems=False,
            return_attn=False,
            mems=None,
            **kwargs
    ):
        """Run the model on a 2-D batch of token ids.

        Args:
            x: integer token ids of shape ``(batch, seq)``.
            return_embeddings: return the normalised hidden states instead of
                logits.
            mask: optional boolean mask over the sequence; it is left-padded
                with True for the memory tokens.
            return_mems: additionally return the updated per-layer memories.
            return_attn: additionally return the post-softmax attention maps
                (only consulted when ``return_mems`` is false).
            mems: optional list of per-layer memories from a previous call.
            **kwargs: forwarded to the attention stack.

        Returns:
            ``out``, ``(out, new_mems)`` or ``(out, attn_maps)``.
        """
        b, _ = x.shape
        num_mem = self.num_memory_tokens

        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        # strip the memory tokens again before producing the output
        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            if exists(mems):
                new_mems = [torch.cat(pair, dim=-2) for pair in zip(mems, hiddens)]
            else:
                new_mems = hiddens
            # Slice bounds must be integers, but ``max_mem_len`` defaults to
            # the float ``0.`` and used to raise TypeError here.
            # NOTE: a length of 0 gives the slice ``-0:``, which keeps the
            # whole tensor (same as the original expression with an int 0).
            mem_len = int(self.max_mem_len)
            new_mems = [t[..., -mem_len:, :].detach() for t in new_mems]
            return out, new_mems

        if return_attn:
            attn_maps = [inter.post_softmax_attn for inter in intermediates.attn_intermediates]
            return out, attn_maps

        return out
641
+