factor diffusion

- Modules/diffusion/diffusion.py +0 -85
- Modules/diffusion/modules.py +96 -77
- Modules/diffusion/sampler.py +70 -46
- Modules/diffusion/utils.py +0 -82
- models.py +77 -8
Modules/diffusion/diffusion.py
DELETED
@@ -1,85 +0,0 @@
-from math import pi
-from random import randint
-from typing import Any, Optional, Sequence, Tuple, Union
-
-import torch
-from einops import rearrange
-from torch import Tensor, nn
-from tqdm import tqdm
-
-from .utils import *
-from .sampler import *
-
-"""
-Diffusion Classes (generic for 1d data)
-"""
-
-
-class Model1d(nn.Module):
-    def __init__(self, unet_type: str = "base", **kwargs):
-        super().__init__()
-        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
-        self.unet = None
-        self.diffusion = None
-
-    def forward(self, x: Tensor, **kwargs) -> Tensor:
-        return self.diffusion(x, **kwargs)
-
-    def sample(self, *args, **kwargs) -> Tensor:
-        return self.diffusion.sample(*args, **kwargs)
-
-
-"""
-Audio Diffusion Classes (specific for 1d audio data)
-"""
-
-
-def get_default_model_kwargs():
-    return dict(
-        channels=128,
-        patch_size=16,
-        multipliers=[1, 2, 4, 4, 4, 4, 4],
-        factors=[4, 4, 4, 2, 2, 2],
-        num_blocks=[2, 2, 2, 2, 2, 2],
-        attentions=[0, 0, 0, 1, 1, 1, 1],
-        attention_heads=8,
-        attention_features=64,
-        attention_multiplier=2,
-        attention_use_rel_pos=False,
-        diffusion_type="v",
-        diffusion_sigma_distribution=UniformDistribution(),
-    )
-
-
-def get_default_sampling_kwargs():
-    return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
-
-class AudioDiffusionConditional(Model1d):
-    def __init__(
-        self,
-        embedding_features: int,
-        embedding_max_length: int,
-        embedding_mask_proba: float = 0.1,
-        **kwargs,
-    ):
-        self.embedding_mask_proba = embedding_mask_proba
-        default_kwargs = dict(
-            **get_default_model_kwargs(),
-            unet_type="cfg",
-            context_embedding_features=embedding_features,
-            context_embedding_max_length=embedding_max_length,
-        )
-        super().__init__(**{**default_kwargs, **kwargs})
-
-    def forward(self, *args, **kwargs):
-        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
-        return super().forward(*args, **{**default_kwargs, **kwargs})
-
-    def sample(self, *args, **kwargs):
-        default_kwargs = dict(
-            **get_default_sampling_kwargs(),
-            embedding_scale=5.0,
-        )
-        return super().sample(*args, **{**default_kwargs, **kwargs})
-
-
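Note: `AudioDiffusionConditional` is re-created (in reduced form) in models.py later in this commit. One idiom from this deleted file survives the move: caller kwargs override defaults through dict merging, as in `super().forward(*args, **{**default_kwargs, **kwargs})`. A minimal standalone sketch of that pattern (names are illustrative, not from the repo):

```python
# Later keys win when two dicts are merged into one literal, so caller kwargs
# override the defaults.
def sample(**kwargs):
    default_kwargs = dict(embedding_scale=5.0, clamp=True)
    return {**default_kwargs, **kwargs}

print(sample())                      # {'embedding_scale': 5.0, 'clamp': True}
print(sample(embedding_scale=1.0))   # {'embedding_scale': 1.0, 'clamp': True}
```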
Modules/diffusion/modules.py
CHANGED
@@ -1,8 +1,5 @@
 from math import floor, log, pi
-
-
-from .utils import *
-
+import torch.nn.functional as F
 import torch
 import torch.nn as nn
 from einops import rearrange, reduce, repeat
@@ -11,9 +8,10 @@ from einops_exts import rearrange_many
 from torch import Tensor, einsum
 
 
-
-
-
+def default(val, d):
+    if val is not None: #exists(val):
+        return val
+    return d # d() if isfunction(d) else d
 
 class AdaLayerNorm(nn.Module):
     def __init__(self, style_dim, channels, eps=1e-5):
@@ -38,6 +36,9 @@ class AdaLayerNorm(nn.Module):
         return x.transpose(1, -1).transpose(-1, -2)
 
 class StyleTransformer1d(nn.Module):
+
+    # artificial_stylets / models.py
+
     def __init__(
         self,
         num_layers: int,
@@ -48,14 +49,14 @@ class StyleTransformer1d(nn.Module):
         use_context_time: bool = True,
         use_rel_pos: bool = False,
         context_features_multiplier: int = 1,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
-        context_features
-        context_embedding_features
-        embedding_max_length
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
+        context_features=None,
+        context_embedding_features=None,
+        embedding_max_length=512,
     ):
         super().__init__()
-
+
         self.blocks = nn.ModuleList(
             [
                 StyleTransformerBlock(
@@ -65,8 +66,8 @@ class StyleTransformer1d(nn.Module):
                     multiplier=multiplier,
                     style_dim=context_features,
                     use_rel_pos=use_rel_pos,
-                    rel_pos_num_buckets=rel_pos_num_buckets,
-                    rel_pos_max_distance=rel_pos_max_distance,
+                    # rel_pos_num_buckets=rel_pos_num_buckets,
+                    # rel_pos_max_distance=rel_pos_max_distance,
                 )
                 for i in range(num_layers)
             ]
@@ -81,11 +82,14 @@ class StyleTransformer1d(nn.Module):
             ),
         )
 
-        use_context_features =
+        use_context_features = context_features is not None
         self.use_context_features = use_context_features
         self.use_context_time = use_context_time
 
         if use_context_time or use_context_features:
+            # print(f'{use_context_time=} {use_context_features=}ooooooooooooooooooooooooooooooooooo')
+            # raise ValueError
+            # True True both context
             context_mapping_features = channels + context_embedding_features
 
             self.to_mapping = nn.Sequential(
@@ -96,7 +100,7 @@ class StyleTransformer1d(nn.Module):
             )
 
         if use_context_time:
-
+
             self.to_time = nn.Sequential(
                 TimePositionalEmbedding(
                     dim=channels, out_features=context_mapping_features
@@ -105,7 +109,7 @@ class StyleTransformer1d(nn.Module):
             )
 
         if use_context_features:
-
+
             self.to_features = nn.Sequential(
                 nn.Linear(
                     in_features=context_features, out_features=context_mapping_features
@@ -119,23 +123,23 @@ class StyleTransformer1d(nn.Module):
 
 
     def get_mapping(
-        self,
-
+        self,
+        time=None,
+        features=None):
         """Combines context time features and features into mapping"""
         items, mapping = [], None
         # Compute time features
         if self.use_context_time:
-
-            assert exists(time), assert_message
+
             items += [self.to_time(time)]
         # Compute features
         if self.use_context_features:
-
-            assert exists(features), assert_message
+
             items += [self.to_features(features)]
 
         # Compute joint mapping
         if self.use_context_time or self.use_context_features:
+            # raise ValueError
             mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
             mapping = self.to_mapping(mapping)
 
@@ -160,8 +164,8 @@ class StyleTransformer1d(nn.Module):
     def forward(self, x: Tensor,
                 time: Tensor,
                 embedding_mask_proba: float = 0.0,
-                embedding
-                features
+                embedding= None,
+                features = None,
                 embedding_scale: float = 1.0) -> Tensor:
 
         b, device = embedding.shape[0], embedding.device
@@ -174,13 +178,18 @@ class StyleTransformer1d(nn.Module):
             embedding = torch.where(batch_mask, fixed_embedding, embedding)
 
         if embedding_scale != 1.0:
-
+
+
             out = self.run(x, time, embedding=embedding, features=features)
             out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
-
+
+            raise ValueError
             return out_masked + (out - out_masked) * embedding_scale
+
         else:
+            # raise ValueError
             return self.run(x, time, embedding=embedding, features=features)
+
 
         return x
 
@@ -194,42 +203,45 @@ class StyleTransformerBlock(nn.Module):
         style_dim: int,
         multiplier: int,
         use_rel_pos: bool,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
-        context_features
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
+        context_features = None,
     ):
         super().__init__()
 
-        self.use_cross_attention =
-
+        self.use_cross_attention = (context_features is not None) and (context_features > 0)
+        # print(f'{rel_pos_num_buckets=} {rel_pos_max_distance=}') # None None
+        # raise ValueError
         self.attention = StyleAttention(
             features=features,
             style_dim=style_dim,
             num_heads=num_heads,
             head_features=head_features,
             use_rel_pos=use_rel_pos,
-            rel_pos_num_buckets=rel_pos_num_buckets,
-            rel_pos_max_distance=rel_pos_max_distance,
+            # rel_pos_num_buckets=rel_pos_num_buckets,
+            # rel_pos_max_distance=rel_pos_max_distance,
         )
 
         if self.use_cross_attention:
-
-
-
-
-
-
-
-
-
-
+            raise ValueError
+            # self.cross_attention = StyleAttention(
+            #     features=features,
+            #     style_dim=style_dim,
+            #     num_heads=num_heads,
+            #     head_features=head_features,
+            #     context_features=context_features,
+            #     use_rel_pos=use_rel_pos,
+            #     rel_pos_num_buckets=rel_pos_num_buckets,
+            #     rel_pos_max_distance=rel_pos_max_distance,
+            # )
 
         self.feed_forward = FeedForward(features=features, multiplier=multiplier)
 
-    def forward(self, x: Tensor, s: Tensor, *, context
+    def forward(self, x: Tensor, s: Tensor, *, context = None) -> Tensor:
         x = self.attention(x, s) + x
         if self.use_cross_attention:
-
+            raise ValueError
+            # x = self.cross_attention(x, s, context=context) + x
         x = self.feed_forward(x) + x
         return x
 
@@ -241,10 +253,10 @@ class StyleAttention(nn.Module):
         style_dim: int,
         head_features: int,
         num_heads: int,
-        context_features
+        context_features = None,
         use_rel_pos: bool,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.context_features = context_features
@@ -264,15 +276,16 @@ class StyleAttention(nn.Module):
             num_heads=num_heads,
             head_features=head_features,
             use_rel_pos=use_rel_pos,
-            rel_pos_num_buckets=rel_pos_num_buckets,
-            rel_pos_max_distance=rel_pos_max_distance,
+            # rel_pos_num_buckets=rel_pos_num_buckets,
+            # rel_pos_max_distance=rel_pos_max_distance,
         )
 
-    def forward(self, x: Tensor, s: Tensor, *, context
-
-
+    def forward(self, x: Tensor, s: Tensor, *, context = None):
+
+        # raise ValueError
         # Use context if provided
         context = default(context, x)
+        # print(context.shape,'ppppppppppppppppppppppppppppppppppppppppppp') # bs, time, 1024
         # Normalize then compute q from input and k,v from context
         x, context = self.norm(x, s), self.norm_context(context, s)
 
@@ -280,7 +293,9 @@ class StyleAttention(nn.Module):
         # Compute and return attention
         return self.attention(q, k, v)
 
-def FeedForward(features: int, multiplier: int) -> nn.Module:
+
+def FeedForward(features,
+                multiplier):
     mid_features = features * multiplier
     return nn.Sequential(
         nn.Linear(in_features=features, out_features=mid_features),
@@ -292,14 +307,14 @@ def FeedForward(features: int, multiplier: int) -> nn.Module:
 class AttentionBase(nn.Module):
     def __init__(
         self,
-        features
+        features,
         *,
-        head_features
-        num_heads
-        use_rel_pos
-        out_features
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
+        head_features,
+        num_heads,
+        use_rel_pos,
+        out_features = None,
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.scale = head_features ** -0.5
@@ -320,7 +335,11 @@ class AttentionBase(nn.Module):
         q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
         # Compute similarity matrix
         sim = einsum("... n d, ... m d -> ... n m", q, k)
-
+
+        # _____THERE_IS_NO_rel_po
+        # sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
+        # print(self.rel_pos)
+
         sim = sim * self.scale
         # Get attention matrix with softmax
         attn = sim.softmax(dim=-1)
@@ -333,15 +352,15 @@ class AttentionBase(nn.Module):
 class Attention(nn.Module):
     def __init__(
         self,
-        features
+        features,
         *,
-        head_features
-        num_heads
-        out_features
-        context_features
-        use_rel_pos
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
+        head_features,
+        num_heads,
+        out_features=None,
+        context_features=None,
+        use_rel_pos,
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
    ):
         super().__init__()
         self.context_features = context_features
@@ -363,13 +382,13 @@ class Attention(nn.Module):
             num_heads=num_heads,
             head_features=head_features,
             use_rel_pos=use_rel_pos,
-            rel_pos_num_buckets=rel_pos_num_buckets,
-            rel_pos_max_distance=rel_pos_max_distance,
+            # rel_pos_num_buckets=rel_pos_num_buckets,
+            # rel_pos_max_distance=rel_pos_max_distance,
         )
 
-    def forward(self, x: Tensor, *, context
-        assert_message = "You must provide a context when using context_features"
-        assert not self.context_features or exists(context), assert_message
+    def forward(self, x: Tensor, *, context = None) -> Tensor:
+        # assert_message = "You must provide a context when using context_features"
+        # assert not self.context_features or exists(context), assert_message
         # Use context if provided
         context = default(context, x)
         # Normalize then compute q from input and k,v from context
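The classifier-free-guidance logic in `StyleTransformer1d.forward` is carried over unchanged by this refactor: conditioning is randomly replaced by a fixed embedding during training, and at sampling time the conditional and unconditional outputs are blended by `embedding_scale`. A standalone sketch of just that arithmetic, with plain random tensors standing in for the fixed (null) embedding and for the two `self.run(...)` outputs; shapes are purely illustrative:

```python
import torch

def rand_bool(shape, proba, device=None):
    # Same helper as in the diff: True with probability `proba`, per batch item.
    return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)

# Training: drop the conditioning for ~10% of the batch (embedding_mask_proba=0.1).
b = 4
embedding = torch.randn(b, 50, 768)              # (batch, tokens, features)
fixed_embedding = torch.zeros_like(embedding)    # stand-in for the model's fixed embedding
batch_mask = rand_bool(shape=(b, 1, 1), proba=0.1, device=embedding.device)
embedding = torch.where(batch_mask, fixed_embedding, embedding)

# Sampling: blend unconditional and conditional predictions.
out = torch.randn(b, 256, 100)          # stands in for self.run(..., embedding=embedding)
out_masked = torch.randn(b, 256, 100)   # stands in for self.run(..., embedding=fixed_embedding)
embedding_scale = 5.0
guided = out_masked + (out - out_masked) * embedding_scale
```

Note that the diff also inserts a `raise ValueError` just before the scaled return, so the `embedding_scale != 1.0` branch is effectively disabled in this commit.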
Modules/diffusion/sampler.py
CHANGED
@@ -1,11 +1,59 @@
 from math import atan, cos, pi, sin, sqrt
-from typing import Any, Callable, List, Optional, Tuple, Type
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
-
+
+from functools import reduce
+from inspect import isfunction
+from math import ceil, floor, log2, pi
+# from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import Generator, Tensor
+
+
+
+
+
+
+# def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
+#     return isinstance(obj, list) or isinstance(obj, tuple)
+
+
+def default(val, d):
+    if val is not None: #exists(val):
+        return val
+    return d #d() if isfunction(d) else d
+
+
+# def to_list(val: Union[T, Sequence[T]]) -> List[T]:
+#     if isinstance(val, tuple):
+#         return list(val)
+#     if isinstance(val, list):
+#         return val
+#     return [val] # type: ignore
+
+
+# def prod(vals: Sequence[int]) -> int:
+#     return reduce(lambda x, y: x * y, vals)
+
+
+def closest_power_2(x: float) -> int:
+    exponent = log2(x)
+    distance_fn = lambda z: abs(x - 2 ** z) # noqa
+    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
+    return 2 ** int(exponent_closest)
+
+def rand_bool(shape, proba, device = None):
+    if proba == 1:
+        return torch.ones(shape, device=device, dtype=torch.bool)
+    elif proba == 0:
+        return torch.zeros(shape, device=device, dtype=torch.bool)
+    else:
+        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
+
+# ============================= END functions from diffusION.utils
 
 
 class LogNormalDistribution():
@@ -29,14 +77,13 @@ class UniformDistribution():
 def to_batch(
     batch_size: int,
     device: torch.device,
-    x
-    xs
-)
-    assert exists(x) ^ exists(xs), "Either x or xs must be provided"
+    x = None,
+    xs = None):
+    # assert exists(x) ^ exists(xs), "Either x or xs must be provided"
     # If x provided use the same for all batch items
-    if exists(x):
+    if x is not None: #exists(x):
         xs = torch.full(size=(batch_size,), fill_value=x).to(device)
-    assert exists(xs)
+    # assert exists(xs)
     return xs
 
 class KDiffusion(nn.Module):
@@ -58,7 +105,7 @@ class KDiffusion(nn.Module):
         self.sigma_distribution = sigma_distribution
         self.dynamic_threshold = dynamic_threshold
 
-    def get_scale_weights(self, sigmas
+    def get_scale_weights(self, sigmas):
         sigma_data = self.sigma_data
         c_noise = torch.log(sigmas) * 0.25
         sigmas = rearrange(sigmas, "b -> b 1 1")
@@ -69,9 +116,9 @@ class KDiffusion(nn.Module):
 
     def denoise_fn(
         self,
-        x_noisy
-        sigmas
-        sigma
+        x_noisy,
+        sigmas = None,
+        sigma = None,
         **kwargs,
     ):
         # raise ValueError
@@ -107,7 +154,7 @@ class KarrasSchedule(nn.Module):
         self.sigma_max = sigma_max
         self.rho = rho
 
-    def forward(self, num_steps: int, device
+    def forward(self, num_steps: int, device):
         rho_inv = 1.0 / self.rho
         steps = torch.arange(num_steps, device=device, dtype=torch.float32)
         sigmas = (
@@ -118,32 +165,7 @@ class KarrasSchedule(nn.Module):
         sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
         return sigmas
 
-
-""" Samplers """
-
-
-class Sampler(nn.Module):
-
-
-
-    def forward(
-        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
-    ) -> Tensor:
-        raise NotImplementedError()
-
-    def inpaint(
-        self,
-        source: Tensor,
-        mask: Tensor,
-        fn: Callable,
-        sigmas: Tensor,
-        num_steps: int,
-        num_resamples: int,
-    ) -> Tensor:
-        raise NotImplementedError("Inpainting not available with current sampler")
-
-
-class ADPM2Sampler(Sampler):
+class ADPM2Sampler(nn.Module):
     """https://www.desmos.com/calculator/jbxjlqd9mb"""
 
     diffusion_types = [KDiffusion,] # VKDiffusion]
@@ -152,15 +174,17 @@ class ADPM2Sampler(Sampler):
         super().__init__()
         self.rho = rho
 
-    def get_sigmas(self,
+    def get_sigmas(self,
+                   sigma,
+                   sigma_next):
         r = self.rho
         sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
         sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
         sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
         return sigma_up, sigma_down, sigma_mid
 
-    def step(self, x
-
+    def step(self, x, fn, sigma, sigma_next):
+
         sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
         # Derivative at sigma (∂x/∂sigma)
         d = (x - fn(x, sigma=sigma)) / sigma
@@ -175,7 +199,7 @@ class ADPM2Sampler(Sampler):
         return x_next
 
     def forward(
-        self, noise
+        self, noise, fn, sigmas, num_steps):
         # raise ValueError
         x = sigmas[0] * noise
         # Denoise to sample
@@ -211,7 +235,7 @@ class DiffusionSampler(nn.Module):
         # raise ValueError
         device = noise.device
         num_steps = default(num_steps, self.num_steps) # type: ignore
-
+
         # Compute sigmas using schedule
         sigmas = self.sigma_schedule(num_steps, device)
         # Append additional kwargs to denoise function (used e.g. for conditional unet)
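`ADPM2Sampler.get_sigmas` (unchanged here apart from the spelled-out signature) splits the transition from `sigma` to `sigma_next` into an ancestral noise part, a deterministic target, and a midpoint. The formulas are plain scalar math and can be checked in isolation; the values below are just an illustration for sigma=1.0, sigma_next=0.5, rho=1.0:

```python
from math import sqrt

def get_sigmas(sigma, sigma_next, r=1.0):
    # Same formulas as ADPM2Sampler.get_sigmas above (r is self.rho).
    sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
    sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
    sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
    return sigma_up, sigma_down, sigma_mid

print(get_sigmas(1.0, 0.5))  # (0.4330127018922193, 0.25, 0.625)
```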
Modules/diffusion/utils.py
DELETED
@@ -1,82 +0,0 @@
-from functools import reduce
-from inspect import isfunction
-from math import ceil, floor, log2, pi
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-from torch import Generator, Tensor
-from typing_extensions import TypeGuard
-
-T = TypeVar("T")
-
-
-def exists(val: Optional[T]) -> TypeGuard[T]:
-    return val is not None
-
-
-def iff(condition: bool, value: T) -> Optional[T]:
-    return value if condition else None
-
-
-def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
-    return isinstance(obj, list) or isinstance(obj, tuple)
-
-
-def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-def to_list(val: Union[T, Sequence[T]]) -> List[T]:
-    if isinstance(val, tuple):
-        return list(val)
-    if isinstance(val, list):
-        return val
-    return [val]  # type: ignore
-
-
-def prod(vals: Sequence[int]) -> int:
-    return reduce(lambda x, y: x * y, vals)
-
-
-def closest_power_2(x: float) -> int:
-    exponent = log2(x)
-    distance_fn = lambda z: abs(x - 2 ** z)  # noqa
-    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
-    return 2 ** int(exponent_closest)
-
-def rand_bool(shape, proba, device = None):
-    if proba == 1:
-        return torch.ones(shape, device=device, dtype=torch.bool)
-    elif proba == 0:
-        return torch.zeros(shape, device=device, dtype=torch.bool)
-    else:
-        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
-
-
-"""
-Kwargs Utils
-"""
-
-
-def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
-    return_dicts: Tuple[Dict, Dict] = ({}, {})
-    for key in d.keys():
-        no_prefix = int(not key.startswith(prefix))
-        return_dicts[no_prefix][key] = d[key]
-    return return_dicts
-
-
-def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
-    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
-    if keep_prefix:
-        return kwargs_with_prefix, kwargs
-    kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
-    return kwargs_no_prefix, kwargs
-
-
-def prefix_dict(prefix: str, d: Dict) -> Dict:
-    return {prefix + str(k): v for k, v in d.items()}
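Of the helpers removed with utils.py, `default`, `closest_power_2` and `rand_bool` are re-inlined into sampler.py above, while `groupby`, which the old `Model1d.__init__` relied on via `groupby("diffusion_", kwargs)`, is not among them. For reference, its behaviour restated standalone with the type hints dropped:

```python
def group_dict_by_prefix(prefix, d):
    # Split a dict into (keys with the prefix, keys without it).
    return_dicts = ({}, {})
    for key in d.keys():
        no_prefix = int(not key.startswith(prefix))
        return_dicts[no_prefix][key] = d[key]
    return return_dicts

def groupby(prefix, d, keep_prefix=False):
    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
    if keep_prefix:
        return kwargs_with_prefix, kwargs
    return {k[len(prefix):]: v for k, v in kwargs_with_prefix.items()}, kwargs

diffusion_kwargs, rest = groupby("diffusion_", {"diffusion_type": "v", "channels": 128})
print(diffusion_kwargs)  # {'type': 'v'}
print(rest)              # {'channels': 128}
```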
models.py
CHANGED
@@ -2,27 +2,96 @@
 
 import os
 import os.path as osp
-
 import copy
 import math
-
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-
 from Utils.ASR.models import ASRCNN
 from Utils.JDC.model import JDCNet
-
 from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
 from Modules.diffusion.modules import StyleTransformer1d
-from Modules.diffusion.diffusion import AudioDiffusionConditional
+# from Modules.diffusion.diffusion import AudioDiffusionConditional
+from munch import Munch
+import yaml
+from math import pi
+from random import randint
+# from typing import Any, Optional, Sequence, Tuple, Union
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+from tqdm import tqdm
+# from Modules.diffusion.utils import *
+# from Modules.diffusion.sampler import *
+
+
+
+
+def get_default_model_kwargs():
+    return dict(
+        channels=128,
+        patch_size=16,
+        multipliers=[1, 2, 4, 4, 4, 4, 4],
+        factors=[4, 4, 4, 2, 2, 2],
+        num_blocks=[2, 2, 2, 2, 2, 2],
+        attentions=[0, 0, 0, 1, 1, 1, 1],
+        attention_heads=8,
+        attention_features=64,
+        attention_multiplier=2,
+        attention_use_rel_pos=False,
+        diffusion_type="v",
+        diffusion_sigma_distribution=UniformDistribution(),
+    )
+
+
+def get_default_sampling_kwargs():
+    return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
+
+class AudioDiffusionConditional(nn.Module):
+    def __init__(
+        self,
+        embedding_features: int,
+        embedding_max_length: int,
+        embedding_mask_proba: float = 0.1,
+        **kwargs,
+    ):
+        self.unet = None
+        self.embedding_mask_proba = embedding_mask_proba
+        # default_kwargs = dict(
+        #     **get_default_model_kwargs(),
+        #     unet_type="cfg",
+        #     context_embedding_features=embedding_features,
+        #     context_embedding_max_length=embedding_max_length,
+        # )
+        super().__init__()
+
+    def forward(self, *args, **kwargs):
+        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
+        return self.diffusion(*args, **{**default_kwargs, **kwargs})
+
+    # def sample(self, *args, **kwargs):
+    #     default_kwargs = dict(
+    #         **get_default_sampling_kwargs(),
+    #         embedding_scale=5.0,
+    #     )
+    #     return super().sample(*args, **{**default_kwargs, **kwargs})
+
+
+
+
+
+
+
+
+
+
+
+
 
 
 
-from munch import Munch
-import yaml
 
 class LearnedDownSample(nn.Module):
     def __init__(self, layer_type, dim_in):

@@ -561,7 +630,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
         channels=args.style_dim*2,
         context_features=args.style_dim*2,
     )
-
+    # this initialises self.diffusion for AudioDiffusionConditional
    diffusion.diffusion = KDiffusion(
         net=diffusion.unet,
         sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),