Dionyssos committed on
Commit a032fce · 1 Parent(s): d9ecbcf

factor diffusion

Modules/diffusion/diffusion.py DELETED
@@ -1,85 +0,0 @@
1
- from math import pi
2
- from random import randint
3
- from typing import Any, Optional, Sequence, Tuple, Union
4
-
5
- import torch
6
- from einops import rearrange
7
- from torch import Tensor, nn
8
- from tqdm import tqdm
9
-
10
- from .utils import *
11
- from .sampler import *
12
-
13
- """
14
- Diffusion Classes (generic for 1d data)
15
- """
16
-
17
-
18
- class Model1d(nn.Module):
19
- def __init__(self, unet_type: str = "base", **kwargs):
20
- super().__init__()
21
- diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
22
- self.unet = None
23
- self.diffusion = None
24
-
25
- def forward(self, x: Tensor, **kwargs) -> Tensor:
26
- return self.diffusion(x, **kwargs)
27
-
28
- def sample(self, *args, **kwargs) -> Tensor:
29
- return self.diffusion.sample(*args, **kwargs)
30
-
31
-
32
- """
33
- Audio Diffusion Classes (specific for 1d audio data)
34
- """
35
-
36
-
37
- def get_default_model_kwargs():
38
- return dict(
39
- channels=128,
40
- patch_size=16,
41
- multipliers=[1, 2, 4, 4, 4, 4, 4],
42
- factors=[4, 4, 4, 2, 2, 2],
43
- num_blocks=[2, 2, 2, 2, 2, 2],
44
- attentions=[0, 0, 0, 1, 1, 1, 1],
45
- attention_heads=8,
46
- attention_features=64,
47
- attention_multiplier=2,
48
- attention_use_rel_pos=False,
49
- diffusion_type="v",
50
- diffusion_sigma_distribution=UniformDistribution(),
51
- )
52
-
53
-
54
- def get_default_sampling_kwargs():
55
- return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
56
-
57
- class AudioDiffusionConditional(Model1d):
58
- def __init__(
59
- self,
60
- embedding_features: int,
61
- embedding_max_length: int,
62
- embedding_mask_proba: float = 0.1,
63
- **kwargs,
64
- ):
65
- self.embedding_mask_proba = embedding_mask_proba
66
- default_kwargs = dict(
67
- **get_default_model_kwargs(),
68
- unet_type="cfg",
69
- context_embedding_features=embedding_features,
70
- context_embedding_max_length=embedding_max_length,
71
- )
72
- super().__init__(**{**default_kwargs, **kwargs})
73
-
74
- def forward(self, *args, **kwargs):
75
- default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
76
- return super().forward(*args, **{**default_kwargs, **kwargs})
77
-
78
- def sample(self, *args, **kwargs):
79
- default_kwargs = dict(
80
- **get_default_sampling_kwargs(),
81
- embedding_scale=5.0,
82
- )
83
- return super().sample(*args, **{**default_kwargs, **kwargs})
84
-
85
-
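The deleted wrapper above merged caller kwargs over library defaults before delegating to the underlying diffusion (see `super().sample(*args, **{**default_kwargs, **kwargs})`). A minimal, self-contained sketch of that merge pattern; the values are illustrative, not taken from any real config:

default_kwargs = dict(embedding_scale=5.0, clamp=True)  # defaults, as in sample() above
kwargs = dict(embedding_scale=1.0)                       # hypothetical caller override
merged = {**default_kwargs, **kwargs}                    # the later unpacking wins
print(merged)  # {'embedding_scale': 1.0, 'clamp': True}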
Modules/diffusion/modules.py CHANGED
@@ -1,8 +1,5 @@
1
  from math import floor, log, pi
2
- from typing import Any, List, Optional, Sequence, Tuple, Union
3
-
4
- from .utils import *
5
-
6
  import torch
7
  import torch.nn as nn
8
  from einops import rearrange, reduce, repeat
@@ -11,9 +8,10 @@ from einops_exts import rearrange_many
11
  from torch import Tensor, einsum
12
 
13
 
14
- """
15
- Utils
16
- """
 
17
 
18
  class AdaLayerNorm(nn.Module):
19
  def __init__(self, style_dim, channels, eps=1e-5):
@@ -38,6 +36,9 @@ class AdaLayerNorm(nn.Module):
38
  return x.transpose(1, -1).transpose(-1, -2)
39
 
40
  class StyleTransformer1d(nn.Module):
 
 
 
41
  def __init__(
42
  self,
43
  num_layers: int,
@@ -48,14 +49,14 @@ class StyleTransformer1d(nn.Module):
48
  use_context_time: bool = True,
49
  use_rel_pos: bool = False,
50
  context_features_multiplier: int = 1,
51
- rel_pos_num_buckets: Optional[int] = None,
52
- rel_pos_max_distance: Optional[int] = None,
53
- context_features: Optional[int] = None,
54
- context_embedding_features: Optional[int] = None,
55
- embedding_max_length: int = 512,
56
  ):
57
  super().__init__()
58
-
59
  self.blocks = nn.ModuleList(
60
  [
61
  StyleTransformerBlock(
@@ -65,8 +66,8 @@ class StyleTransformer1d(nn.Module):
65
  multiplier=multiplier,
66
  style_dim=context_features,
67
  use_rel_pos=use_rel_pos,
68
- rel_pos_num_buckets=rel_pos_num_buckets,
69
- rel_pos_max_distance=rel_pos_max_distance,
70
  )
71
  for i in range(num_layers)
72
  ]
@@ -81,11 +82,14 @@ class StyleTransformer1d(nn.Module):
81
  ),
82
  )
83
 
84
- use_context_features = exists(context_features)
85
  self.use_context_features = use_context_features
86
  self.use_context_time = use_context_time
87
 
88
  if use_context_time or use_context_features:
 
 
 
89
  context_mapping_features = channels + context_embedding_features
90
 
91
  self.to_mapping = nn.Sequential(
@@ -96,7 +100,7 @@ class StyleTransformer1d(nn.Module):
96
  )
97
 
98
  if use_context_time:
99
- assert exists(context_mapping_features)
100
  self.to_time = nn.Sequential(
101
  TimePositionalEmbedding(
102
  dim=channels, out_features=context_mapping_features
@@ -105,7 +109,7 @@ class StyleTransformer1d(nn.Module):
105
  )
106
 
107
  if use_context_features:
108
- assert exists(context_features) and exists(context_mapping_features)
109
  self.to_features = nn.Sequential(
110
  nn.Linear(
111
  in_features=context_features, out_features=context_mapping_features
@@ -119,23 +123,23 @@ class StyleTransformer1d(nn.Module):
119
 
120
 
121
  def get_mapping(
122
- self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
123
- ) -> Optional[Tensor]:
 
124
  """Combines context time features and features into mapping"""
125
  items, mapping = [], None
126
  # Compute time features
127
  if self.use_context_time:
128
- assert_message = "use_context_time=True but no time features provided"
129
- assert exists(time), assert_message
130
  items += [self.to_time(time)]
131
  # Compute features
132
  if self.use_context_features:
133
- assert_message = "context_features exists but no features provided"
134
- assert exists(features), assert_message
135
  items += [self.to_features(features)]
136
 
137
  # Compute joint mapping
138
  if self.use_context_time or self.use_context_features:
 
139
  mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
140
  mapping = self.to_mapping(mapping)
141
 
@@ -160,8 +164,8 @@ class StyleTransformer1d(nn.Module):
160
  def forward(self, x: Tensor,
161
  time: Tensor,
162
  embedding_mask_proba: float = 0.0,
163
- embedding: Optional[Tensor] = None,
164
- features: Optional[Tensor] = None,
165
  embedding_scale: float = 1.0) -> Tensor:
166
 
167
  b, device = embedding.shape[0], embedding.device
@@ -174,13 +178,18 @@ class StyleTransformer1d(nn.Module):
174
  embedding = torch.where(batch_mask, fixed_embedding, embedding)
175
 
176
  if embedding_scale != 1.0:
177
- # Compute both normal and fixed embedding outputs
 
178
  out = self.run(x, time, embedding=embedding, features=features)
179
  out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
180
- # Scale conditional output using classifier-free guidance
 
181
  return out_masked + (out - out_masked) * embedding_scale
 
182
  else:
 
183
  return self.run(x, time, embedding=embedding, features=features)
 
184
 
185
  return x
186
 
@@ -194,42 +203,45 @@ class StyleTransformerBlock(nn.Module):
194
  style_dim: int,
195
  multiplier: int,
196
  use_rel_pos: bool,
197
- rel_pos_num_buckets: Optional[int] = None,
198
- rel_pos_max_distance: Optional[int] = None,
199
- context_features: Optional[int] = None,
200
  ):
201
  super().__init__()
202
 
203
- self.use_cross_attention = exists(context_features) and context_features > 0
204
-
 
205
  self.attention = StyleAttention(
206
  features=features,
207
  style_dim=style_dim,
208
  num_heads=num_heads,
209
  head_features=head_features,
210
  use_rel_pos=use_rel_pos,
211
- rel_pos_num_buckets=rel_pos_num_buckets,
212
- rel_pos_max_distance=rel_pos_max_distance,
213
  )
214
 
215
  if self.use_cross_attention:
216
- self.cross_attention = StyleAttention(
217
- features=features,
218
- style_dim=style_dim,
219
- num_heads=num_heads,
220
- head_features=head_features,
221
- context_features=context_features,
222
- use_rel_pos=use_rel_pos,
223
- rel_pos_num_buckets=rel_pos_num_buckets,
224
- rel_pos_max_distance=rel_pos_max_distance,
225
- )
 
226
 
227
  self.feed_forward = FeedForward(features=features, multiplier=multiplier)
228
 
229
- def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
230
  x = self.attention(x, s) + x
231
  if self.use_cross_attention:
232
- x = self.cross_attention(x, s, context=context) + x
 
233
  x = self.feed_forward(x) + x
234
  return x
235
 
@@ -241,10 +253,10 @@ class StyleAttention(nn.Module):
241
  style_dim: int,
242
  head_features: int,
243
  num_heads: int,
244
- context_features: Optional[int] = None,
245
  use_rel_pos: bool,
246
- rel_pos_num_buckets: Optional[int] = None,
247
- rel_pos_max_distance: Optional[int] = None,
248
  ):
249
  super().__init__()
250
  self.context_features = context_features
@@ -264,15 +276,16 @@ class StyleAttention(nn.Module):
264
  num_heads=num_heads,
265
  head_features=head_features,
266
  use_rel_pos=use_rel_pos,
267
- rel_pos_num_buckets=rel_pos_num_buckets,
268
- rel_pos_max_distance=rel_pos_max_distance,
269
  )
270
 
271
- def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
272
- assert_message = "You must provide a context when using context_features"
273
- assert not self.context_features or exists(context), assert_message
274
  # Use context if provided
275
  context = default(context, x)
 
276
  # Normalize then compute q from input and k,v from context
277
  x, context = self.norm(x, s), self.norm_context(context, s)
278
 
@@ -280,7 +293,9 @@ class StyleAttention(nn.Module):
280
  # Compute and return attention
281
  return self.attention(q, k, v)
282
 
283
- def FeedForward(features: int, multiplier: int) -> nn.Module:
 
 
284
  mid_features = features * multiplier
285
  return nn.Sequential(
286
  nn.Linear(in_features=features, out_features=mid_features),
@@ -292,14 +307,14 @@ def FeedForward(features: int, multiplier: int) -> nn.Module:
292
  class AttentionBase(nn.Module):
293
  def __init__(
294
  self,
295
- features: int,
296
  *,
297
- head_features: int,
298
- num_heads: int,
299
- use_rel_pos: bool,
300
- out_features: Optional[int] = None,
301
- rel_pos_num_buckets: Optional[int] = None,
302
- rel_pos_max_distance: Optional[int] = None,
303
  ):
304
  super().__init__()
305
  self.scale = head_features ** -0.5
@@ -320,7 +335,11 @@ class AttentionBase(nn.Module):
320
  q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
321
  # Compute similarity matrix
322
  sim = einsum("... n d, ... m d -> ... n m", q, k)
323
- sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
 
 
 
 
324
  sim = sim * self.scale
325
  # Get attention matrix with softmax
326
  attn = sim.softmax(dim=-1)
@@ -333,15 +352,15 @@ class AttentionBase(nn.Module):
333
  class Attention(nn.Module):
334
  def __init__(
335
  self,
336
- features: int,
337
  *,
338
- head_features: int,
339
- num_heads: int,
340
- out_features: Optional[int] = None,
341
- context_features: Optional[int] = None,
342
- use_rel_pos: bool,
343
- rel_pos_num_buckets: Optional[int] = None,
344
- rel_pos_max_distance: Optional[int] = None,
345
  ):
346
  super().__init__()
347
  self.context_features = context_features
@@ -363,13 +382,13 @@ class Attention(nn.Module):
363
  num_heads=num_heads,
364
  head_features=head_features,
365
  use_rel_pos=use_rel_pos,
366
- rel_pos_num_buckets=rel_pos_num_buckets,
367
- rel_pos_max_distance=rel_pos_max_distance,
368
  )
369
 
370
- def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
371
- assert_message = "You must provide a context when using context_features"
372
- assert not self.context_features or exists(context), assert_message
373
  # Use context if provided
374
  context = default(context, x)
375
  # Normalize then compute q from input and k,v from context
 
1
  from math import floor, log, pi
2
+ import torch.nn.functional as F
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  from einops import rearrange, reduce, repeat
 
8
  from torch import Tensor, einsum
9
 
10
 
11
+ def default(val, d):
12
+ if val is not None: #exists(val):
13
+ return val
14
+ return d # d() if isfunction(d) else d
15
 
16
  class AdaLayerNorm(nn.Module):
17
  def __init__(self, style_dim, channels, eps=1e-5):
 
36
  return x.transpose(1, -1).transpose(-1, -2)
37
 
38
  class StyleTransformer1d(nn.Module):
39
+
40
+ # artificial_stylets / models.py
41
+
42
  def __init__(
43
  self,
44
  num_layers: int,
 
49
  use_context_time: bool = True,
50
  use_rel_pos: bool = False,
51
  context_features_multiplier: int = 1,
52
+ # rel_pos_num_buckets: Optional[int] = None,
53
+ # rel_pos_max_distance: Optional[int] = None,
54
+ context_features=None,
55
+ context_embedding_features=None,
56
+ embedding_max_length=512,
57
  ):
58
  super().__init__()
59
+
60
  self.blocks = nn.ModuleList(
61
  [
62
  StyleTransformerBlock(
 
66
  multiplier=multiplier,
67
  style_dim=context_features,
68
  use_rel_pos=use_rel_pos,
69
+ # rel_pos_num_buckets=rel_pos_num_buckets,
70
+ # rel_pos_max_distance=rel_pos_max_distance,
71
  )
72
  for i in range(num_layers)
73
  ]
 
82
  ),
83
  )
84
 
85
+ use_context_features = context_features is not None
86
  self.use_context_features = use_context_features
87
  self.use_context_time = use_context_time
88
 
89
  if use_context_time or use_context_features:
90
+ # print(f'{use_context_time=} {use_context_features=}ooooooooooooooooooooooooooooooooooo')
91
+ # raise ValueError
92
+ # True True both context
93
  context_mapping_features = channels + context_embedding_features
94
 
95
  self.to_mapping = nn.Sequential(
 
100
  )
101
 
102
  if use_context_time:
103
+
104
  self.to_time = nn.Sequential(
105
  TimePositionalEmbedding(
106
  dim=channels, out_features=context_mapping_features
 
109
  )
110
 
111
  if use_context_features:
112
+
113
  self.to_features = nn.Sequential(
114
  nn.Linear(
115
  in_features=context_features, out_features=context_mapping_features
 
123
 
124
 
125
  def get_mapping(
126
+ self,
127
+ time=None,
128
+ features=None):
129
  """Combines context time features and features into mapping"""
130
  items, mapping = [], None
131
  # Compute time features
132
  if self.use_context_time:
133
+
 
134
  items += [self.to_time(time)]
135
  # Compute features
136
  if self.use_context_features:
137
+
 
138
  items += [self.to_features(features)]
139
 
140
  # Compute joint mapping
141
  if self.use_context_time or self.use_context_features:
142
+ # raise ValueError
143
  mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
144
  mapping = self.to_mapping(mapping)
145
 
 
164
  def forward(self, x: Tensor,
165
  time: Tensor,
166
  embedding_mask_proba: float = 0.0,
167
+ embedding= None,
168
+ features = None,
169
  embedding_scale: float = 1.0) -> Tensor:
170
 
171
  b, device = embedding.shape[0], embedding.device
 
178
  embedding = torch.where(batch_mask, fixed_embedding, embedding)
179
 
180
  if embedding_scale != 1.0:
181
+
182
+
183
  out = self.run(x, time, embedding=embedding, features=features)
184
  out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
185
+
186
+ raise ValueError
187
  return out_masked + (out - out_masked) * embedding_scale
188
+
189
  else:
190
+ # raise ValueError
191
  return self.run(x, time, embedding=embedding, features=features)
192
+
193
 
194
  return x
195
 
 
203
  style_dim: int,
204
  multiplier: int,
205
  use_rel_pos: bool,
206
+ # rel_pos_num_buckets: Optional[int] = None,
207
+ # rel_pos_max_distance: Optional[int] = None,
208
+ context_features = None,
209
  ):
210
  super().__init__()
211
 
212
+ self.use_cross_attention = (context_features is not None) and (context_features > 0)
213
+ # print(f'{rel_pos_num_buckets=} {rel_pos_max_distance=}') # None None
214
+ # raise ValueError
215
  self.attention = StyleAttention(
216
  features=features,
217
  style_dim=style_dim,
218
  num_heads=num_heads,
219
  head_features=head_features,
220
  use_rel_pos=use_rel_pos,
221
+ # rel_pos_num_buckets=rel_pos_num_buckets,
222
+ # rel_pos_max_distance=rel_pos_max_distance,
223
  )
224
 
225
  if self.use_cross_attention:
226
+ raise ValueError
227
+ # self.cross_attention = StyleAttention(
228
+ # features=features,
229
+ # style_dim=style_dim,
230
+ # num_heads=num_heads,
231
+ # head_features=head_features,
232
+ # context_features=context_features,
233
+ # use_rel_pos=use_rel_pos,
234
+ # rel_pos_num_buckets=rel_pos_num_buckets,
235
+ # rel_pos_max_distance=rel_pos_max_distance,
236
+ # )
237
 
238
  self.feed_forward = FeedForward(features=features, multiplier=multiplier)
239
 
240
+ def forward(self, x: Tensor, s: Tensor, *, context = None) -> Tensor:
241
  x = self.attention(x, s) + x
242
  if self.use_cross_attention:
243
+ raise ValueError
244
+ # x = self.cross_attention(x, s, context=context) + x
245
  x = self.feed_forward(x) + x
246
  return x
247
 
 
253
  style_dim: int,
254
  head_features: int,
255
  num_heads: int,
256
+ context_features = None,
257
  use_rel_pos: bool,
258
+ # rel_pos_num_buckets: Optional[int] = None,
259
+ # rel_pos_max_distance: Optional[int] = None,
260
  ):
261
  super().__init__()
262
  self.context_features = context_features
 
276
  num_heads=num_heads,
277
  head_features=head_features,
278
  use_rel_pos=use_rel_pos,
279
+ # rel_pos_num_buckets=rel_pos_num_buckets,
280
+ # rel_pos_max_distance=rel_pos_max_distance,
281
  )
282
 
283
+ def forward(self, x: Tensor, s: Tensor, *, context = None):
284
+
285
+ # raise ValueError
286
  # Use context if provided
287
  context = default(context, x)
288
+ # print(context.shape,'ppppppppppppppppppppppppppppppppppppppppppp') # bs, time, 1024
289
  # Normalize then compute q from input and k,v from context
290
  x, context = self.norm(x, s), self.norm_context(context, s)
291
 
 
293
  # Compute and return attention
294
  return self.attention(q, k, v)
295
 
296
+
297
+ def FeedForward(features,
298
+ multiplier):
299
  mid_features = features * multiplier
300
  return nn.Sequential(
301
  nn.Linear(in_features=features, out_features=mid_features),
 
307
  class AttentionBase(nn.Module):
308
  def __init__(
309
  self,
310
+ features,
311
  *,
312
+ head_features,
313
+ num_heads,
314
+ use_rel_pos,
315
+ out_features = None,
316
+ # rel_pos_num_buckets: Optional[int] = None,
317
+ # rel_pos_max_distance: Optional[int] = None,
318
  ):
319
  super().__init__()
320
  self.scale = head_features ** -0.5
 
335
  q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
336
  # Compute similarity matrix
337
  sim = einsum("... n d, ... m d -> ... n m", q, k)
338
+
339
+ # _____THERE_IS_NO_rel_po
340
+ # sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
341
+ # print(self.rel_pos)
342
+
343
  sim = sim * self.scale
344
  # Get attention matrix with softmax
345
  attn = sim.softmax(dim=-1)
 
352
  class Attention(nn.Module):
353
  def __init__(
354
  self,
355
+ features,
356
  *,
357
+ head_features,
358
+ num_heads,
359
+ out_features=None,
360
+ context_features=None,
361
+ use_rel_pos,
362
+ # rel_pos_num_buckets: Optional[int] = None,
363
+ # rel_pos_max_distance: Optional[int] = None,
364
  ):
365
  super().__init__()
366
  self.context_features = context_features
 
382
  num_heads=num_heads,
383
  head_features=head_features,
384
  use_rel_pos=use_rel_pos,
385
+ # rel_pos_num_buckets=rel_pos_num_buckets,
386
+ # rel_pos_max_distance=rel_pos_max_distance,
387
  )
388
 
389
+ def forward(self, x: Tensor, *, context = None) -> Tensor:
390
+ # assert_message = "You must provide a context when using context_features"
391
+ # assert not self.context_features or exists(context), assert_message
392
  # Use context if provided
393
  context = default(context, x)
394
  # Normalize then compute q from input and k,v from context
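The `embedding_scale` branch in `StyleTransformer1d.forward` above implements classifier-free guidance by mixing the conditional output with the output obtained from the fixed (masked) embedding. A tiny numeric sketch of that mix, using stand-in tensors rather than real network outputs:

import torch

out = torch.tensor([1.0, 2.0])          # conditional prediction (stand-in)
out_masked = torch.tensor([0.5, 0.5])   # prediction with the fixed/masked embedding (stand-in)
for embedding_scale in (1.0, 5.0):
    guided = out_masked + (out - out_masked) * embedding_scale
    print(embedding_scale, guided)
# 1.0 -> tensor([1., 2.])  (pure conditional output)
# 5.0 -> tensor([3., 8.])  (extrapolated away from the masked output)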
Modules/diffusion/sampler.py CHANGED
@@ -1,11 +1,59 @@
1
  from math import atan, cos, pi, sin, sqrt
2
- from typing import Any, Callable, List, Optional, Tuple, Type
3
- import torch
4
  import torch.nn as nn
5
- import torch.nn.functional as F
6
  from einops import rearrange
7
  from torch import Tensor
8
- from .utils import *
 
9
 
10
 
11
  class LogNormalDistribution():
@@ -29,14 +77,13 @@ class UniformDistribution():
29
  def to_batch(
30
  batch_size: int,
31
  device: torch.device,
32
- x: Optional[float] = None,
33
- xs: Optional[Tensor] = None,
34
- ) -> Tensor:
35
- assert exists(x) ^ exists(xs), "Either x or xs must be provided"
36
  # If x provided use the same for all batch items
37
- if exists(x):
38
  xs = torch.full(size=(batch_size,), fill_value=x).to(device)
39
- assert exists(xs)
40
  return xs
41
 
42
  class KDiffusion(nn.Module):
@@ -58,7 +105,7 @@ class KDiffusion(nn.Module):
58
  self.sigma_distribution = sigma_distribution
59
  self.dynamic_threshold = dynamic_threshold
60
 
61
- def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
62
  sigma_data = self.sigma_data
63
  c_noise = torch.log(sigmas) * 0.25
64
  sigmas = rearrange(sigmas, "b -> b 1 1")
@@ -69,9 +116,9 @@ class KDiffusion(nn.Module):
69
 
70
  def denoise_fn(
71
  self,
72
- x_noisy: Tensor,
73
- sigmas: Optional[Tensor] = None,
74
- sigma: Optional[float] = None,
75
  **kwargs,
76
  ):
77
  # raise ValueError
@@ -107,7 +154,7 @@ class KarrasSchedule(nn.Module):
107
  self.sigma_max = sigma_max
108
  self.rho = rho
109
 
110
- def forward(self, num_steps: int, device: Any) -> Tensor:
111
  rho_inv = 1.0 / self.rho
112
  steps = torch.arange(num_steps, device=device, dtype=torch.float32)
113
  sigmas = (
@@ -118,32 +165,7 @@ class KarrasSchedule(nn.Module):
118
  sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
119
  return sigmas
120
 
121
-
122
- """ Samplers """
123
-
124
-
125
- class Sampler(nn.Module):
126
-
127
-
128
-
129
- def forward(
130
- self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
131
- ) -> Tensor:
132
- raise NotImplementedError()
133
-
134
- def inpaint(
135
- self,
136
- source: Tensor,
137
- mask: Tensor,
138
- fn: Callable,
139
- sigmas: Tensor,
140
- num_steps: int,
141
- num_resamples: int,
142
- ) -> Tensor:
143
- raise NotImplementedError("Inpainting not available with current sampler")
144
-
145
-
146
- class ADPM2Sampler(Sampler):
147
  """https://www.desmos.com/calculator/jbxjlqd9mb"""
148
 
149
  diffusion_types = [KDiffusion,] # VKDiffusion]
@@ -152,15 +174,17 @@ class ADPM2Sampler(Sampler):
152
  super().__init__()
153
  self.rho = rho
154
 
155
- def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
 
 
156
  r = self.rho
157
  sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
158
  sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
159
  sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
160
  return sigma_up, sigma_down, sigma_mid
161
 
162
- def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
163
- # Sigma steps
164
  sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
165
  # Derivative at sigma (∂x/∂sigma)
166
  d = (x - fn(x, sigma=sigma)) / sigma
@@ -175,7 +199,7 @@ class ADPM2Sampler(Sampler):
175
  return x_next
176
 
177
  def forward(
178
- self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int):
179
  # raise ValueError
180
  x = sigmas[0] * noise
181
  # Denoise to sample
@@ -211,7 +235,7 @@ class DiffusionSampler(nn.Module):
211
  # raise ValueError
212
  device = noise.device
213
  num_steps = default(num_steps, self.num_steps) # type: ignore
214
- assert exists(num_steps), "Parameter `num_steps` must be provided"
215
  # Compute sigmas using schedule
216
  sigmas = self.sigma_schedule(num_steps, device)
217
  # Append additional kwargs to denoise function (used e.g. for conditional unet)
 
1
  from math import atan, cos, pi, sin, sqrt
 
 
2
  import torch.nn as nn
 
3
  from einops import rearrange
4
  from torch import Tensor
5
+
6
+ from functools import reduce
7
+ from inspect import isfunction
8
+ from math import ceil, floor, log2, pi
9
+ # from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from torch import Generator, Tensor
14
+
15
+
16
+
17
+
18
+
19
+
20
+ # def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
21
+ # return isinstance(obj, list) or isinstance(obj, tuple)
22
+
23
+
24
+ def default(val, d):
25
+ if val is not None: #exists(val):
26
+ return val
27
+ return d #d() if isfunction(d) else d
28
+
29
+
30
+ # def to_list(val: Union[T, Sequence[T]]) -> List[T]:
31
+ # if isinstance(val, tuple):
32
+ # return list(val)
33
+ # if isinstance(val, list):
34
+ # return val
35
+ # return [val] # type: ignore
36
+
37
+
38
+ # def prod(vals: Sequence[int]) -> int:
39
+ # return reduce(lambda x, y: x * y, vals)
40
+
41
+
42
+ def closest_power_2(x: float) -> int:
43
+ exponent = log2(x)
44
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
45
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
46
+ return 2 ** int(exponent_closest)
47
+
48
+ def rand_bool(shape, proba, device = None):
49
+ if proba == 1:
50
+ return torch.ones(shape, device=device, dtype=torch.bool)
51
+ elif proba == 0:
52
+ return torch.zeros(shape, device=device, dtype=torch.bool)
53
+ else:
54
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
55
+
56
+ # ============================= END functions from diffusION.utils
57
 
58
 
59
  class LogNormalDistribution():
 
77
  def to_batch(
78
  batch_size: int,
79
  device: torch.device,
80
+ x = None,
81
+ xs = None):
82
+ # assert exists(x) ^ exists(xs), "Either x or xs must be provided"
 
83
  # If x provided use the same for all batch items
84
+ if x is not None: #exists(x):
85
  xs = torch.full(size=(batch_size,), fill_value=x).to(device)
86
+ # assert exists(xs)
87
  return xs
88
 
89
  class KDiffusion(nn.Module):
 
105
  self.sigma_distribution = sigma_distribution
106
  self.dynamic_threshold = dynamic_threshold
107
 
108
+ def get_scale_weights(self, sigmas):
109
  sigma_data = self.sigma_data
110
  c_noise = torch.log(sigmas) * 0.25
111
  sigmas = rearrange(sigmas, "b -> b 1 1")
 
116
 
117
  def denoise_fn(
118
  self,
119
+ x_noisy,
120
+ sigmas = None,
121
+ sigma = None,
122
  **kwargs,
123
  ):
124
  # raise ValueError
 
154
  self.sigma_max = sigma_max
155
  self.rho = rho
156
 
157
+ def forward(self, num_steps: int, device):
158
  rho_inv = 1.0 / self.rho
159
  steps = torch.arange(num_steps, device=device, dtype=torch.float32)
160
  sigmas = (
 
165
  sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
166
  return sigmas
167
 
168
+ class ADPM2Sampler(nn.Module):
169
  """https://www.desmos.com/calculator/jbxjlqd9mb"""
170
 
171
  diffusion_types = [KDiffusion,] # VKDiffusion]
 
174
  super().__init__()
175
  self.rho = rho
176
 
177
+ def get_sigmas(self,
178
+ sigma,
179
+ sigma_next):
180
  r = self.rho
181
  sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
182
  sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
183
  sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
184
  return sigma_up, sigma_down, sigma_mid
185
 
186
+ def step(self, x, fn, sigma, sigma_next):
187
+
188
  sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
189
  # Derivative at sigma (∂x/∂sigma)
190
  d = (x - fn(x, sigma=sigma)) / sigma
 
199
  return x_next
200
 
201
  def forward(
202
+ self, noise, fn, sigmas, num_steps):
203
  # raise ValueError
204
  x = sigmas[0] * noise
205
  # Denoise to sample
 
235
  # raise ValueError
236
  device = noise.device
237
  num_steps = default(num_steps, self.num_steps) # type: ignore
238
+
239
  # Compute sigmas using schedule
240
  sigmas = self.sigma_schedule(num_steps, device)
241
  # Append additional kwargs to denoise function (used e.g. for conditional unet)
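`ADPM2Sampler.step` above splits each noise level into an "up", "down" and "mid" sigma before its two model evaluations. A self-contained sketch of that split; `rho=1.0` and the concrete sigma values are assumptions for illustration only:

from math import sqrt

def get_sigmas(sigma, sigma_next, rho=1.0):
    sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
    sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
    sigma_mid = ((sigma ** (1 / rho) + sigma_down ** (1 / rho)) / 2) ** rho
    return sigma_up, sigma_down, sigma_mid

print(get_sigmas(10.0, 5.0))
# (4.33..., 2.5, 6.25): noise re-injected, deterministic target sigma,
# and the midpoint at which the second derivative is evaluated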
Modules/diffusion/utils.py DELETED
@@ -1,82 +0,0 @@
1
- from functools import reduce
2
- from inspect import isfunction
3
- from math import ceil, floor, log2, pi
4
- from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
-
6
- import torch
7
- import torch.nn.functional as F
8
- from einops import rearrange
9
- from torch import Generator, Tensor
10
- from typing_extensions import TypeGuard
11
-
12
- T = TypeVar("T")
13
-
14
-
15
- def exists(val: Optional[T]) -> TypeGuard[T]:
16
- return val is not None
17
-
18
-
19
- def iff(condition: bool, value: T) -> Optional[T]:
20
- return value if condition else None
21
-
22
-
23
- def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
- return isinstance(obj, list) or isinstance(obj, tuple)
25
-
26
-
27
- def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
28
- if exists(val):
29
- return val
30
- return d() if isfunction(d) else d
31
-
32
-
33
- def to_list(val: Union[T, Sequence[T]]) -> List[T]:
34
- if isinstance(val, tuple):
35
- return list(val)
36
- if isinstance(val, list):
37
- return val
38
- return [val] # type: ignore
39
-
40
-
41
- def prod(vals: Sequence[int]) -> int:
42
- return reduce(lambda x, y: x * y, vals)
43
-
44
-
45
- def closest_power_2(x: float) -> int:
46
- exponent = log2(x)
47
- distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
- exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
- return 2 ** int(exponent_closest)
50
-
51
- def rand_bool(shape, proba, device = None):
52
- if proba == 1:
53
- return torch.ones(shape, device=device, dtype=torch.bool)
54
- elif proba == 0:
55
- return torch.zeros(shape, device=device, dtype=torch.bool)
56
- else:
57
- return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
58
-
59
-
60
- """
61
- Kwargs Utils
62
- """
63
-
64
-
65
- def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
66
- return_dicts: Tuple[Dict, Dict] = ({}, {})
67
- for key in d.keys():
68
- no_prefix = int(not key.startswith(prefix))
69
- return_dicts[no_prefix][key] = d[key]
70
- return return_dicts
71
-
72
-
73
- def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
74
- kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
75
- if keep_prefix:
76
- return kwargs_with_prefix, kwargs
77
- kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
78
- return kwargs_no_prefix, kwargs
79
-
80
-
81
- def prefix_dict(prefix: str, d: Dict) -> Dict:
82
- return {prefix + str(k): v for k, v in d.items()}
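For reference, the kwargs-splitting helpers removed here behave as follows. This is a self-contained copy of the deleted functions plus a made-up usage example:

from typing import Dict, Tuple

def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    return_dicts: Tuple[Dict, Dict] = ({}, {})
    for key in d.keys():
        no_prefix = int(not key.startswith(prefix))
        return_dicts[no_prefix][key] = d[key]
    return return_dicts

def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
    if keep_prefix:
        return kwargs_with_prefix, kwargs
    kwargs_no_prefix = {k[len(prefix):]: v for k, v in kwargs_with_prefix.items()}
    return kwargs_no_prefix, kwargs

diffusion_kwargs, rest = groupby("diffusion_", {"diffusion_type": "v", "channels": 128})
print(diffusion_kwargs)  # {'type': 'v'}
print(rest)              # {'channels': 128}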
models.py CHANGED
@@ -2,27 +2,96 @@
2
 
3
  import os
4
  import os.path as osp
5
-
6
  import copy
7
  import math
8
-
9
  import numpy as np
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
-
15
  from Utils.ASR.models import ASRCNN
16
  from Utils.JDC.model import JDCNet
17
-
18
  from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
19
  from Modules.diffusion.modules import StyleTransformer1d
20
- from Modules.diffusion.diffusion import AudioDiffusionConditional
21
 
22
 
23
 
24
- from munch import Munch
25
- import yaml
26
 
27
  class LearnedDownSample(nn.Module):
28
  def __init__(self, layer_type, dim_in):
@@ -561,7 +630,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
561
  channels=args.style_dim*2,
562
  context_features=args.style_dim*2,
563
  )
564
-
565
  diffusion.diffusion = KDiffusion(
566
  net=diffusion.unet,
567
  sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
 
2
 
3
  import os
4
  import os.path as osp
 
5
  import copy
6
  import math
 
7
  import numpy as np
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 
12
  from Utils.ASR.models import ASRCNN
13
  from Utils.JDC.model import JDCNet
 
14
  from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
15
  from Modules.diffusion.modules import StyleTransformer1d
16
+ # from Modules.diffusion.diffusion import AudioDiffusionConditional
17
+ from munch import Munch
18
+ import yaml
19
+ from math import pi
20
+ from random import randint
21
+ # from typing import Any, Optional, Sequence, Tuple, Union
22
+ import torch
23
+ from einops import rearrange
24
+ from torch import Tensor, nn
25
+ from tqdm import tqdm
26
+ # from Modules.diffusion.utils import *
27
+ # from Modules.diffusion.sampler import *
28
+
29
+
30
+
31
+
32
+ def get_default_model_kwargs():
33
+ return dict(
34
+ channels=128,
35
+ patch_size=16,
36
+ multipliers=[1, 2, 4, 4, 4, 4, 4],
37
+ factors=[4, 4, 4, 2, 2, 2],
38
+ num_blocks=[2, 2, 2, 2, 2, 2],
39
+ attentions=[0, 0, 0, 1, 1, 1, 1],
40
+ attention_heads=8,
41
+ attention_features=64,
42
+ attention_multiplier=2,
43
+ attention_use_rel_pos=False,
44
+ diffusion_type="v",
45
+ diffusion_sigma_distribution=UniformDistribution(),
46
+ )
47
+
48
+
49
+ def get_default_sampling_kwargs():
50
+ return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
51
+
52
+ class AudioDiffusionConditional(nn.Module):
53
+ def __init__(
54
+ self,
55
+ embedding_features: int,
56
+ embedding_max_length: int,
57
+ embedding_mask_proba: float = 0.1,
58
+ **kwargs,
59
+ ):
60
+ self.unet = None
61
+ self.embedding_mask_proba = embedding_mask_proba
62
+ # default_kwargs = dict(
63
+ # **get_default_model_kwargs(),
64
+ # unet_type="cfg",
65
+ # context_embedding_features=embedding_features,
66
+ # context_embedding_max_length=embedding_max_length,
67
+ # )
68
+ super().__init__()
69
+
70
+ def forward(self, *args, **kwargs):
71
+ default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
72
+ return self.diffusion(*args, **{**default_kwargs, **kwargs})
73
+
74
+ # def sample(self, *args, **kwargs):
75
+ # default_kwargs = dict(
76
+ # **get_default_sampling_kwargs(),
77
+ # embedding_scale=5.0,
78
+ # )
79
+ # return super().sample(*args, **{**default_kwargs, **kwargs})
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
90
+
91
+
92
 
93
 
94
 
 
 
95
 
96
  class LearnedDownSample(nn.Module):
97
  def __init__(self, layer_type, dim_in):
 
630
  channels=args.style_dim*2,
631
  context_features=args.style_dim*2,
632
  )
633
+ # this initialises self.diffusion for AudioDiffusionConditional
634
  diffusion.diffusion = KDiffusion(
635
  net=diffusion.unet,
636
  sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
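build_model wires the noise-level distribution into KDiffusion via LogNormalDistribution(mean, std). The body of that class is not shown in this diff, so the following is only an assumption about the usual log-normal sigma sampling such a distribution performs; the mean/std stand in for args.diffusion.dist values:

import torch

mean, std = -3.0, 1.0                       # placeholder values for args.diffusion.dist.*
num_samples = 4
sigmas = (mean + std * torch.randn(num_samples)).exp()  # log-normal noise levels
print(sigmas)  # per-item sigmas used to corrupt x during training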