16 |
17 |
# UnlimitedMusicGen
18 |
This is my modification of the Audiocraft project to enable unlimited audio generation. I have added a few features to the original project to make this possible, and I have also extended the Gradio interface to make it easier to use.
19 |
20 |
# Audiocraft
21 |

22 |

148 |
## License
149 |
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
150 |
* The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
72 |
def test_convert_audio_resample(self):
73 |
b, c, dur = 2, 1, 4.
74 |
sr = 3
75 |
new_sr = 2
76 |
audio = get_batch_white_noise(b, c, int(sr * dur))
77 |
out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
78 |
out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
79 |
assert torch.allclose(out, out_j)
80 |
81 |
82 |
class TestNormalizeAudio:
83 |
84 |
def test_clip_wav(self):
85 |
b, c, dur = 2, 1, 4.
86 |
sr = 3
87 |
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
88 |
89 |
assert audio.abs().max() <= 1
90 |
91 |
def test_normalize_audio_clip(self):
92 |
b, c, dur = 2, 1, 4.
93 |
sr = 3
94 |
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
95 |
norm_audio = normalize_audio(audio, strategy='clip')
96 |
assert norm_audio.abs().max() <= 1
97 |
98 |
def test_normalize_audio_rms(self):
99 |
b, c, dur = 2, 1, 4.
100 |
sr = 3
101 |
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
102 |
norm_audio = normalize_audio(audio, strategy='rms')
103 |
assert norm_audio.abs().max() <= 1
104 |
105 |
def test_normalize_audio_peak(self):
106 |
b, c, dur = 2, 1, 4.
107 |
sr = 3
108 |
audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
109 |
norm_audio = normalize_audio(audio, strategy='peak')
110 |
assert norm_audio.abs().max() <= 1
@@ -1,60 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import random
8 |
9 |
import numpy as np
10 |
import torch
11 |
12 |
from audiocraft.models import EncodecModel
13 |
from audiocraft.modules import SEANetEncoder, SEANetDecoder
14 |
from audiocraft.quantization import DummyQuantizer
15 |
16 |
17 |
class TestEncodecModel:
18 |
19 |
def _create_encodec_model(self,
20 |
sample_rate: int,
21 |
channels: int,
22 |
dim: int = 5,
23 |
n_filters: int = 3,
24 |
n_residual_layers: int = 1,
25 |
ratios: list = [5, 4, 3, 2],
26 |
27 |
frame_rate =
28 |
encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
29 |
n_residual_layers=n_residual_layers, ratios=ratios)
30 |
decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
31 |
n_residual_layers=n_residual_layers, ratios=ratios)
32 |
quantizer = DummyQuantizer()
33 |
model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
34 |
sample_rate=sample_rate, channels=channels, **kwargs)
35 |
return model
36 |
37 |
def test_model(self):
38 |
39 |
sample_rate = 24_000
40 |
channels = 1
41 |
model = self._create_encodec_model(sample_rate, channels)
42 |
for _ in range(10):
43 |
length = random.randrange(1, 10_000)
44 |
x = torch.randn(2, channels, length)
45 |
res = model(x)
46 |
assert res.x.shape == x.shape
47 |
48 |
def test_model_renorm(self):
49 |
50 |
sample_rate = 24_000
51 |
channels = 1
52 |
model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
53 |
model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
54 |
55 |
for _ in range(10):
56 |
length = random.randrange(1, 10_000)
57 |
x = torch.randn(2, channels, length)
58 |
codes, scales = model_nonorm.encode(x)
59 |
codes, scales = model_renorm.encode(x)
60 |
assert scales is not None
@@ -1,50 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import pytest
8 |
import torch
9 |
10 |
from audiocraft.models import MusicGen
11 |
12 |
13 |
class TestSEANetModel:
14 |
def get_musicgen(self):
15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
16 |
17 |
return mg
18 |
19 |
def test_base(self):
20 |
mg = self.get_musicgen()
21 |
assert mg.frame_rate == 25
22 |
assert mg.sample_rate == 32000
23 |
assert mg.audio_channels == 1
24 |
25 |
def test_generate_unconditional(self):
26 |
mg = self.get_musicgen()
27 |
wav = mg.generate_unconditional(3)
28 |
assert list(wav.shape) == [3, 1, 64000]
29 |
30 |
def test_generate_continuation(self):
31 |
mg = self.get_musicgen()
32 |
prompt = torch.randn(3, 1, 32000)
33 |
wav = mg.generate_continuation(prompt, 32000)
34 |
assert list(wav.shape) == [3, 1, 64000]
35 |
36 |
prompt = torch.randn(2, 1, 32000)
37 |
wav = mg.generate_continuation(
38 |
prompt, 32000, ['youpi', 'lapin dort'])
39 |
assert list(wav.shape) == [2, 1, 64000]
40 |
41 |
prompt = torch.randn(2, 1, 32000)
42 |
with pytest.raises(AssertionError):
43 |
wav = mg.generate_continuation(
44 |
prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
45 |
46 |
def test_generate(self):
47 |
mg = self.get_musicgen()
48 |
wav = mg.generate(
49 |
['youpi', 'lapin dort'])
50 |
assert list(wav.shape) == [2, 1, 64000]
@@ -1,5 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
@@ -1,246 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import pytest
8 |
import torch
9 |
10 |
from audiocraft.modules.codebooks_patterns import (
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
class TestParallelPatternProvider:
19 |
20 |
@pytest.mark.parametrize("n_q", [1, 4, 32])
21 |
@pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
22 |
def test_get_pattern(self, n_q: int, timesteps: int):
23 |
provider = ParallelPatternProvider(n_q)
24 |
pattern = provider.get_pattern(timesteps)
25 |
# + 1 to account for 1st step
26 |
assert len(pattern.layout) == timesteps + 1
27 |
28 |
@pytest.mark.parametrize("n_q", [1, 4, 32])
29 |
@pytest.mark.parametrize("timesteps", [8, 16, 100])
30 |
def test_pattern_content(self, n_q: int, timesteps: int):
31 |
provider = ParallelPatternProvider(n_q)
32 |
pattern = provider.get_pattern(timesteps)
33 |
for s, v in enumerate(pattern.layout):
34 |
for i, code in enumerate(v):
35 |
assert i == code.q
36 |
assert code.t == s - 1 # account for the 1st empty step
37 |
38 |
@pytest.mark.parametrize("n_q", [1, 4, 32])
39 |
@pytest.mark.parametrize("timesteps", [8, 16, 100])
40 |
def test_pattern_max_delay(self, n_q: int, timesteps: int):
41 |
provider = ParallelPatternProvider(n_q)
42 |
pattern = provider.get_pattern(timesteps)
43 |
assert pattern.max_delay == 0
44 |
assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
45 |
46 |
47 |
class TestDelayedPatternProvider:
48 |
49 |
@pytest.mark.parametrize("n_q", [1, 4, 32])
50 |
@pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
51 |
def test_get_pattern(self, n_q: int, timesteps: int):
52 |
delays = [
53 |
54 |
[0] + [1] * (n_q - 1),
55 |
[0] + [4] * (n_q - 1),
56 |
57 |
for delay in delays:
58 |
provider = DelayedPatternProvider(n_q, delay)
59 |
pattern = provider.get_pattern(timesteps)
60 |
# + 1 to account for 1st step
61 |
assert len(pattern.layout) == timesteps + max(delay) + 1
62 |
63 |
@pytest.mark.parametrize("n_q", [1, 4, 32])
64 |
@pytest.mark.parametrize("timesteps", [8, 16, 100])
65 |
def test_pattern_content(self, n_q: int, timesteps: int):
66 |
provider = DelayedPatternProvider(n_q)
67 |
pattern = provider.get_pattern(timesteps)
68 |
for s, v in enumerate(pattern.layout):
69 |
for i, code in enumerate(v):
70 |
assert i == code.q
71 |
assert code.t == max(0, s - code.q - 1)
72 |
73 |
@pytest.mark.parametrize("timesteps", [8, 16, 100])
74 |
@pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]])
75 |
def test_pattern_max_delay(self, timesteps: int, delay: list):
76 |
provider = DelayedPatternProvider(len(delay), delay)
77 |
pattern = provider.get_pattern(timesteps)
78 |
assert pattern.max_delay == max(delay)
79 |
assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
80 |
81 |
82 |
class TestUnrolledPatternProvider:
83 |
84 |
@pytest.mark.parametrize("timesteps", [0, 1, 16])
85 |
@pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
86 |
@pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
87 |
def test_get_pattern(self, timesteps: int, flattening: list, delays: list):
88 |
n_q = len(flattening)
89 |
max_delay = max(delays)
90 |
provider = UnrolledPatternProvider(n_q, flattening, delays)
91 |
pattern = provider.get_pattern(timesteps)
92 |
assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay
93 |
94 |
@pytest.mark.parametrize("timesteps", [0, 1, 16])
95 |
@pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
96 |
@pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
97 |
def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list):
98 |
n_q = len(flattening)
99 |
max_delay = max(delays)
100 |
provider = UnrolledPatternProvider(n_q, flattening, delays)
101 |
pattern = provider.get_pattern(timesteps)
102 |
assert pattern.max_delay == max_delay
103 |
104 |
105 |
class TestPattern:
106 |
107 |
def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
    """Reference method to build the sequence from the pattern without using fancy scatter.

    Walks ``pattern.layout`` one step at a time and, for every ``(t, q)``
    entry listed at step ``s``, copies ``z[:, q, t]`` into ``[:, q, s]`` of
    the output. Entries whose ``t`` falls outside the input are skipped, and
    slots that are never written keep ``special_token``.

    Returns:
        A long tensor of shape ``(bs, n_q, len(pattern.layout))``.
    """
    batch, n_codebooks, n_steps = z.shape
    codes = z.cpu().numpy()
    assert n_codebooks == pattern.n_q
    assert n_steps <= pattern.timesteps
    seq_shape = (batch, n_codebooks, len(pattern.layout))
    # torch.full already initialises every slot with the special token.
    seq = torch.full(seq_shape, special_token, dtype=torch.long).numpy()
    for step, coords in enumerate(pattern.layout):
        for t, q in coords:
            if t >= n_steps:
                continue
            seq[:, q, step] = codes[:, q, t]
    return torch.from_numpy(seq)
120 |
121 |
def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
    """Reference method to revert the sequence from the pattern without using fancy scatter.

    Inverse of the reference build: for every ``(t, q)`` entry listed at
    layout step ``s``, copies ``z[:, q, s]`` back to ``[:, q, t]`` of the
    output. Slots that no layout entry maps to keep ``special_token``.

    Returns:
        A long tensor of shape ``(bs, pattern.n_q, pattern.timesteps)``.
    """
    seq = z.cpu().numpy()
    batch, n_codebooks, _ = seq.shape
    assert pattern.n_q == n_codebooks
    out_shape = (batch, pattern.n_q, pattern.timesteps)
    # torch.full already initialises every slot with the special token.
    restored = torch.full(out_shape, special_token, dtype=torch.long).numpy()
    for step, coords in enumerate(pattern.layout):
        for t, q in coords:
            if t >= pattern.timesteps:
                continue
            restored[:, q, t] = seq[:, q, step]
    return torch.from_numpy(restored)
133 |
134 |
def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float):
    """Reference method to revert the logits from the pattern without using fancy scatter.

    The first layout step is skipped, so sequence index ``s`` of ``z`` is
    paired with ``pattern.layout[s + 1]``. For every ``(t, q)`` entry of
    that step, ``z[:, :, q, s]`` is copied to ``[:, :, q, t]`` of the output.
    Slots that are never written keep ``special_token``.

    Returns:
        A float tensor of shape ``(bs, card, pattern.n_q, pattern.timesteps)``.
    """
    logits = z.cpu().numpy()
    batch, card, n_codebooks, seq_len = logits.shape
    assert pattern.n_q == n_codebooks
    out_shape = (batch, card, pattern.n_q, pattern.timesteps)
    # torch.full already initialises every slot with the special token.
    restored = torch.full(out_shape, special_token, dtype=torch.float).numpy()
    # zip() stops at the shorter operand, which bounds the step index by seq_len.
    for step, coords in zip(range(seq_len), pattern.layout[1:]):
        for t, q in coords:
            if t >= pattern.timesteps:
                continue
            restored[:, :, q, t] = logits[:, :, q, step]
    return torch.from_numpy(restored)
148 |
149 |
def _get_pattern_providers(self, n_q: int):
150 |
pattern_provider_1 = ParallelPatternProvider(n_q)
151 |
pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q)))
152 |
pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1))
153 |
pattern_provider_4 = UnrolledPatternProvider(
154 |
n_q, flattening=list(range(n_q)), delays=[0] * n_q
155 |
156 |
pattern_provider_5 = UnrolledPatternProvider(
157 |
n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q
158 |
159 |
pattern_provider_6 = UnrolledPatternProvider(
160 |
n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1)
161 |
162 |
return [
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
@pytest.mark.parametrize("n_q", [1, 4, 32])
172 |
@pytest.mark.parametrize("timesteps", [16, 72])
173 |
def test_build_pattern_sequence(self, n_q: int, timesteps: int):
174 |
bs = 2
175 |
card = 256
176 |
special_token = card
177 |
178 |
pattern_providers = self._get_pattern_providers(n_q)
179 |
for pattern_provider in pattern_providers:
180 |
pattern = pattern_provider.get_pattern(timesteps)
181 |
# we can correctly build the sequence from the pattern
182 |
z = torch.randint(0, card, (bs, n_q, timesteps))
183 |
ref_res = self.ref_build_pattern_sequence(z, pattern, special_token)
184 |
res, indexes, mask = pattern.build_pattern_sequence(z, special_token)
185 |
assert (res == ref_res).float().mean() == 1.0
186 |
187 |
# expected assertion fails on the number of timesteps
188 |
invalid_timesteps = [timesteps + 1]
189 |
if pattern.num_sequence_steps != pattern.timesteps:
190 |
191 |
for i_timesteps in invalid_timesteps:
192 |
z2 = torch.randint(0, card, (bs, n_q, i_timesteps))
193 |
with pytest.raises(AssertionError):
194 |
pattern.build_pattern_sequence(z2, special_token)
195 |
196 |
# expected assertion fails on the number of codebooks
197 |
invalid_qs = [0, n_q - 1, n_q + 1]
198 |
for i_q in invalid_qs:
199 |
z3 = torch.randint(0, card, (bs, i_q, timesteps))
200 |
with pytest.raises(AssertionError):
201 |
pattern.build_pattern_sequence(z3, special_token)
202 |
203 |
@pytest.mark.parametrize("n_q", [1, 4, 32])
204 |
@pytest.mark.parametrize("timesteps", [16, 72])
205 |
def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
206 |
bs = 2
207 |
card = 256
208 |
special_token = card
209 |
210 |
pattern_providers = self._get_pattern_providers(n_q)
211 |
for pattern_provider in pattern_providers:
212 |
pattern = pattern_provider.get_pattern(timesteps)
213 |
# this works assuming previous tests are successful
214 |
z = torch.randint(0, card, (bs, n_q, timesteps))
215 |
s = self.ref_build_pattern_sequence(z, pattern, special_token)
216 |
ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token)
217 |
# ensure our reference script retrieve the original sequence
218 |
assert z.shape == ref_out.shape
219 |
assert (z == ref_out).float().mean() == 1.0
220 |
# now we can test the scatter version
221 |
out, indexes, mask = pattern.revert_pattern_sequence(s, special_token)
222 |
assert out.shape == ref_out.shape
223 |
assert (out == ref_out).float().mean() == 1.0
224 |
225 |
@pytest.mark.parametrize("n_q", [1, 4, 32])
226 |
@pytest.mark.parametrize("timesteps", [16, 72])
227 |
@pytest.mark.parametrize("card", [1, 2, 256, 1024])
228 |
def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int):
229 |
bs = 2
230 |
special_token = card
231 |
logits_special_token = float('nan')
232 |
233 |
pattern_providers = self._get_pattern_providers(n_q)
234 |
for pattern_provider in pattern_providers:
235 |
pattern = pattern_provider.get_pattern(timesteps)
236 |
# this works assuming previous tests are successful
237 |
z = torch.randint(0, card, (bs, n_q, timesteps))
238 |
s = self.ref_build_pattern_sequence(z, pattern, special_token)
239 |
logits = torch.randn((bs, card, n_q, s.shape[-1]))
240 |
ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token)
241 |
# ensure our reference script retrieve the original sequence
242 |
assert ref_out.shape == torch.Size([bs, card, n_q, timesteps])
243 |
# now we can test the scatter version
244 |
out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token)
245 |
assert out.shape == ref_out.shape
246 |
assert (out == ref_out).float().mean() == 1.0
@@ -1,203 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from itertools import product
8 |
import math
9 |
import random
10 |
11 |
import pytest
12 |
import torch
13 |
from torch import nn
14 |
15 |
from audiocraft.modules import (
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
def test_get_extra_padding_for_conv1d():
26 |
# TODO: Implement me!
27 |
28 |
29 |
30 |
def test_pad1d_zeros():
31 |
x = torch.randn(1, 1, 20)
32 |
33 |
xp1 = pad1d(x, (0, 5), mode='constant', value=0.)
34 |
assert xp1.shape[-1] == 25
35 |
xp2 = pad1d(x, (5, 5), mode='constant', value=0.)
36 |
assert xp2.shape[-1] == 30
37 |
xp3 = pad1d(x, (0, 0), mode='constant', value=0.)
38 |
assert xp3.shape[-1] == 20
39 |
xp4 = pad1d(x, (10, 30), mode='constant', value=0.)
40 |
assert xp4.shape[-1] == 60
41 |
42 |
with pytest.raises(AssertionError):
43 |
pad1d(x, (-1, 0), mode='constant', value=0.)
44 |
45 |
with pytest.raises(AssertionError):
46 |
pad1d(x, (0, -1), mode='constant', value=0.)
47 |
48 |
with pytest.raises(AssertionError):
49 |
pad1d(x, (-1, -1), mode='constant', value=0.)
50 |
51 |
52 |
def test_pad1d_reflect():
53 |
x = torch.randn(1, 1, 20)
54 |
55 |
xp1 = pad1d(x, (0, 5), mode='reflect', value=0.)
56 |
assert xp1.shape[-1] == 25
57 |
xp2 = pad1d(x, (5, 5), mode='reflect', value=0.)
58 |
assert xp2.shape[-1] == 30
59 |
xp3 = pad1d(x, (0, 0), mode='reflect', value=0.)
60 |
assert xp3.shape[-1] == 20
61 |
xp4 = pad1d(x, (10, 30), mode='reflect', value=0.)
62 |
assert xp4.shape[-1] == 60
63 |
64 |
with pytest.raises(AssertionError):
65 |
pad1d(x, (-1, 0), mode='reflect', value=0.)
66 |
67 |
with pytest.raises(AssertionError):
68 |
pad1d(x, (0, -1), mode='reflect', value=0.)
69 |
70 |
with pytest.raises(AssertionError):
71 |
pad1d(x, (-1, -1), mode='reflect', value=0.)
72 |
73 |
74 |
def test_unpad1d():
75 |
x = torch.randn(1, 1, 20)
76 |
77 |
u1 = unpad1d(x, (5, 5))
78 |
assert u1.shape[-1] == 10
79 |
u2 = unpad1d(x, (0, 5))
80 |
assert u2.shape[-1] == 15
81 |
u3 = unpad1d(x, (5, 0))
82 |
assert u3.shape[-1] == 15
83 |
u4 = unpad1d(x, (0, 0))
84 |
assert u4.shape[-1] == x.shape[-1]
85 |
86 |
with pytest.raises(AssertionError):
87 |
unpad1d(x, (-1, 0))
88 |
89 |
with pytest.raises(AssertionError):
90 |
unpad1d(x, (0, -1))
91 |
92 |
with pytest.raises(AssertionError):
93 |
unpad1d(x, (-1, -1))
94 |
95 |
96 |
class TestNormConv1d:
97 |
98 |
def test_norm_conv1d_modules(self):
    """Check the wrapped modules and output shape of ``NormConv1d`` for each norm.

    Builds one layer per normalization (``weight_norm``, ``time_group_norm``,
    ``none``), verifies the type of its ``norm`` and ``conv`` attributes, and
    checks that the output has shape ``[N, C_out, expected_out_length]``.
    """
    C_out, kernel_size, stride = 1, 4, 1
    # The conv is built without padding, so the input must be at least
    # kernel_size long; drawing T from 1 made the test fail sporadically.
    N, C, T = 2, 2, random.randrange(kernel_size, 100_000)
    t0 = torch.randn(N, C, T)

    expected_out_length = (T - kernel_size) // stride + 1
    wn_conv = NormConv1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
    gn_conv = NormConv1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
    nn_conv = NormConv1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')

    assert isinstance(wn_conv.norm, nn.Identity)
    assert isinstance(wn_conv.conv, nn.Conv1d)

    assert isinstance(gn_conv.norm, nn.GroupNorm)
    assert isinstance(gn_conv.conv, nn.Conv1d)

    assert isinstance(nn_conv.norm, nn.Identity)
    assert isinstance(nn_conv.conv, nn.Conv1d)

    for conv_layer in [wn_conv, gn_conv, nn_conv]:
        out = conv_layer(t0)
        assert isinstance(out, torch.Tensor)
        assert list(out.shape) == [N, C_out, expected_out_length]
121 |
122 |
123 |
class TestNormConvTranspose1d:
124 |
125 |
def test_normalizations(self):
126 |
N, C, T = 2, 2, random.randrange(1, 100_000)
127 |
t0 = torch.randn(N, C, T)
128 |
129 |
C_out, kernel_size, stride = 1, 4, 1
130 |
expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1
131 |
132 |
wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
133 |
gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
134 |
nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')
135 |
136 |
assert isinstance(wn_convtr.norm, nn.Identity)
137 |
assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d)
138 |
139 |
assert isinstance(gn_convtr.norm, nn.GroupNorm)
140 |
assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d)
141 |
142 |
assert isinstance(nn_convtr.norm, nn.Identity)
143 |
assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d)
144 |
145 |
for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]:
146 |
out = convtr_layer(t0)
147 |
assert isinstance(out, torch.Tensor)
148 |
assert list(out.shape) == [N, C_out, expected_out_length]
149 |
150 |
151 |
class TestStreamableConv1d:
152 |
153 |
def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation):
    """Return the output length expected from a StreamableConv1d layer.

    Args:
        length: number of input time steps.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        dilation: convolution dilation.

    Returns:
        The expected number of output time steps, as an int.
    """
    # StreamableConv1d internally pads to make sure that the last window is full
    total_pad = dilation * (kernel_size - 1) - (stride - 1)
    frame_count = math.ceil((length - kernel_size + total_pad) / stride + 1)
    padded_length = stride * (frame_count - 1) + (kernel_size - total_pad)
    return padded_length // stride
159 |
160 |
def test_streamable_conv1d(self):
161 |
N, C, T = 2, 2, random.randrange(1, 100_000)
162 |
t0 = torch.randn(N, C, T)
163 |
C_out = 1
164 |
165 |
# conv params are [(kernel_size, stride, dilation)]
166 |
conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)]
167 |
for causal, (kernel_size, stride, dilation) in product([False, True], conv_params):
168 |
expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation)
169 |
sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal)
170 |
out = sconv(t0)
171 |
assert isinstance(out, torch.Tensor)
172 |
print(list(out.shape), [N, C_out, expected_out_length])
173 |
assert list(out.shape) == [N, C_out, expected_out_length]
174 |
175 |
176 |
class TestStreamableConvTranspose1d:
177 |
178 |
def get_streamable_convtr1d_output_length(self, length, kernel_size, stride):
    """Return the output length expected from a StreamableConvTranspose1d layer.

    Computes ``(length - 1) * stride + kernel_size`` and subtracts the
    ``kernel_size - stride`` samples of total padding.

    Args:
        length: number of input time steps.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
    """
    trimmed = kernel_size - stride
    untrimmed = stride * (length - 1) + (kernel_size - 1) + 1
    return untrimmed - trimmed
181 |
182 |
def test_streamable_convtr1d(self):
    """Check argument validation and output shape of ``StreamableConvTranspose1d``.

    First verifies that each invalid ``(causal, trim_right_ratio)`` pair is
    rejected with an ``AssertionError``, then sweeps valid causal and conv
    settings and compares the output shape with
    ``get_streamable_convtr1d_output_length``.
    """
    N, C, T = 2, 2, random.randrange(1, 100_000)
    t0 = torch.randn(N, C, T)

    C_out = 1

    # invalid params are [(causal, trim_right)]. Each one needs its own
    # pytest.raises block: the context exits on the first exception, so
    # stacking several calls in one block only ever checks the first.
    invalid_params = [(False, 0.5), (True, -1.), (True, 2)]
    for causal, trim_right_ratio in invalid_params:
        with pytest.raises(AssertionError):
            StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=causal,
                                      trim_right_ratio=trim_right_ratio)

    # causal params are [(causal, trim_right)]
    causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)]
    # conv params are [(kernel_size, stride)]
    conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)]
    for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params):
        expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride)
        sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride,
                                            causal=causal, trim_right_ratio=trim_right_ratio)
        out = sconvtr(t0)
        assert isinstance(out, torch.Tensor)
        assert list(out.shape) == [N, C_out, expected_out_length]
@@ -1,32 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import random
8 |
import torch
9 |
10 |
from audiocraft.modules.lstm import StreamableLSTM
11 |
12 |
13 |
class TestStreamableLSTM:
14 |
15 |
def test_lstm(self):
16 |
B, C, T = 4, 2, random.randint(1, 100)
17 |
18 |
lstm = StreamableLSTM(C, 3, skip=False)
19 |
x = torch.randn(B, C, T)
20 |
y = lstm(x)
21 |
22 |
23 |
assert y.shape == torch.Size([B, C, T])
24 |
25 |
def test_lstm_skip(self):
26 |
B, C, T = 4, 2, random.randint(1, 100)
27 |
28 |
lstm = StreamableLSTM(C, 3, skip=True)
29 |
x = torch.randn(B, C, T)
30 |
y = lstm(x)
31 |
32 |
assert y.shape == torch.Size([B, C, T])
@@ -1,160 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import torch
8 |
9 |
from audiocraft.modules.rope import RotaryEmbedding
10 |
from audiocraft.modules.transformer import StreamingTransformer
11 |
12 |
13 |
def test_rope():
14 |
B, T, H, C = 8, 75, 16, 128
15 |
16 |
rope = RotaryEmbedding(dim=C)
17 |
xq = torch.rand((B, T, H, C))
18 |
xk = torch.rand((B, T, H, C))
19 |
xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
20 |
21 |
assert list(xq_out.shape) == [B, T, H, C]
22 |
assert list(xk_out.shape) == [B, T, H, C]
23 |
24 |
25 |
def test_rope_io_dtypes():
26 |
B, T, H, C = 8, 75, 16, 128
27 |
28 |
rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
29 |
rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)
30 |
31 |
# Test bfloat16 inputs w/ both 32 and 64 precision rope.
32 |
xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
33 |
xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
34 |
xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16)
35 |
assert xq_out.dtype == torch.bfloat16
36 |
xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16)
37 |
assert xq_out.dtype == torch.bfloat16
38 |
39 |
# Test float32 inputs w/ both 32 and 64 precision rope.
40 |
xq_32 = torch.rand((B, T, H, C)).to(torch.float32)
41 |
xk_32 = torch.rand((B, T, H, C)).to(torch.float32)
42 |
xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32)
43 |
assert xq_out.dtype == torch.float32
44 |
xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32)
45 |
assert xq_out.dtype == torch.float32
46 |
47 |
48 |
def test_transformer_with_rope():
49 |
50 |
for pos in ['rope', 'sin_rope']:
51 |
tr = StreamingTransformer(
52 |
16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
53 |
54 |
55 |
steps = 12
56 |
x = torch.randn(3, steps, 16)
57 |
58 |
out = tr(x)
59 |
assert list(out.shape) == list(x.shape)
60 |
61 |
62 |
63 |
def test_rope_streaming():
64 |
65 |
tr = StreamingTransformer(
66 |
16, 4, 2, causal=True, dropout=0.,
67 |
custom=True, positional_embedding='rope')
68 |
69 |
steps = 12
70 |
x = torch.randn(3, steps, 16)
71 |
72 |
ref = tr(x)
73 |
74 |
with tr.streaming():
75 |
outs = []
76 |
frame_sizes = [1] * steps
77 |
78 |
for frame_size in frame_sizes:
79 |
frame = x[:, :frame_size]
80 |
x = x[:, frame_size:]
81 |
82 |
83 |
out =, dim=1)
84 |
assert list(out.shape) == [3, steps, 16]
85 |
delta = torch.norm(out - ref) / torch.norm(out)
86 |
assert delta < 1e-6, delta
87 |
88 |
89 |
90 |
def test_rope_streaming_past_context():
91 |
92 |
93 |
for context in [None, 10]:
94 |
tr = StreamingTransformer(
95 |
16, 4, 1 if context else 2,
96 |
causal=True, past_context=context, custom=True,
97 |
dropout=0., positional_embedding='rope')
98 |
99 |
100 |
steps = 20
101 |
x = torch.randn(3, steps, 16)
102 |
ref = tr(x)
103 |
104 |
with tr.streaming():
105 |
outs = []
106 |
frame_sizes = [1] * steps
107 |
108 |
for frame_size in frame_sizes:
109 |
frame = x[:, :frame_size]
110 |
x = x[:, frame_size:]
111 |
112 |
113 |
out =, dim=1)
114 |
assert list(out.shape) == [3, steps, 16]
115 |
delta = torch.norm(out - ref) / torch.norm(out)
116 |
assert delta < 1e-6, delta
117 |
118 |
119 |
def test_rope_memory_efficient():
120 |
121 |
tr = StreamingTransformer(
122 |
16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
123 |
124 |
tr_mem_efficient = StreamingTransformer(
125 |
16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
126 |
127 |
128 |
129 |
steps = 12
130 |
x = torch.randn(3, steps, 16)
131 |
132 |
with torch.no_grad():
133 |
y = tr(x)
134 |
y2 = tr_mem_efficient(x)
135 |
# Check at float precision b/c this is the rope default.
136 |
assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
137 |
138 |
139 |
def test_rope_with_xpos():
140 |
B, T, H, C = 8, 75, 16, 128
141 |
142 |
rope = RotaryEmbedding(dim=C, xpos=True)
143 |
xq = torch.rand((B, T, H, C))
144 |
xk = torch.rand((B, T, H, C))
145 |
xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
146 |
147 |
assert list(xq_out.shape) == [B, T, H, C]
148 |
assert list(xk_out.shape) == [B, T, H, C]
149 |
150 |
151 |
def test_positional_scale():
152 |
B, T, H, C = 8, 75, 16, 128
153 |
154 |
rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
155 |
xq = torch.rand((B, T, H, C))
156 |
xk = torch.rand((B, T, H, C))
157 |
xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
158 |
159 |
assert torch.allclose(xq, xq_out)
160 |
assert torch.allclose(xk, xk_out)
@@ -1,115 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from itertools import product
8 |
9 |
import pytest
10 |
import torch
11 |
12 |
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
13 |
from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
14 |
15 |
16 |
class TestSEANetModel:
17 |
18 |
def test_base(self):
19 |
encoder = SEANetEncoder()
20 |
decoder = SEANetDecoder()
21 |
22 |
x = torch.randn(1, 1, 24000)
23 |
z = encoder(x)
24 |
assert list(z.shape) == [1, 128, 75], z.shape
25 |
y = decoder(z)
26 |
assert y.shape == x.shape, (x.shape, y.shape)
27 |
28 |
def test_causal(self):
29 |
encoder = SEANetEncoder(causal=True)
30 |
decoder = SEANetDecoder(causal=True)
31 |
x = torch.randn(1, 1, 24000)
32 |
33 |
z = encoder(x)
34 |
assert list(z.shape) == [1, 128, 75], z.shape
35 |
y = decoder(z)
36 |
assert y.shape == x.shape, (x.shape, y.shape)
37 |
38 |
def test_conv_skip_connection(self):
39 |
encoder = SEANetEncoder(true_skip=False)
40 |
decoder = SEANetDecoder(true_skip=False)
41 |
42 |
x = torch.randn(1, 1, 24000)
43 |
z = encoder(x)
44 |
assert list(z.shape) == [1, 128, 75], z.shape
45 |
y = decoder(z)
46 |
assert y.shape == x.shape, (x.shape, y.shape)
47 |
48 |
def test_seanet_encoder_decoder_final_act(self):
49 |
encoder = SEANetEncoder(true_skip=False)
50 |
decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
51 |
52 |
x = torch.randn(1, 1, 24000)
53 |
z = encoder(x)
54 |
assert list(z.shape) == [1, 128, 75], z.shape
55 |
y = decoder(z)
56 |
assert y.shape == x.shape, (x.shape, y.shape)
57 |
58 |
def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
    """Assert that every conv in ``encoder.model`` carries the expected ``norm_type``.

    Blocks are counted in iteration order, one per ``StreamableConv1d``
    found directly in ``encoder.model``. The first ``n_disable_blocks``
    blocks are expected to use ``'none'``; all later ones are expected to
    use ``norm``.

    Args:
        encoder: encoder whose ``model`` layers are inspected.
        n_disable_blocks: number of leading blocks expected to have no norm.
        norm: norm name expected on the remaining blocks.
    """
    n_blocks = 0
    for layer in encoder.model:
        if isinstance(layer, StreamableConv1d):
            n_blocks += 1
            # Compute the expected value first: writing
            # `assert a == 'none' if cond else norm` parses as
            # `(a == 'none') if cond else norm` and is vacuous when cond is false.
            expected = 'none' if n_blocks <= n_disable_blocks else norm
            assert layer.conv.norm_type == expected
        elif isinstance(layer, SEANetResnetBlock):
            # here we add + 1 to n_blocks as we increment n_blocks just after the block
            expected = 'none' if (n_blocks + 1) <= n_disable_blocks else norm
            for resnet_layer in layer.block:
                if isinstance(resnet_layer, StreamableConv1d):
                    assert resnet_layer.conv.norm_type == expected
70 |
def test_encoder_disable_norm(self):
71 |
n_residuals = [0, 1, 3]
72 |
disable_blocks = [0, 1, 2, 3, 4, 5, 6]
73 |
norms = ['weight_norm', 'none']
74 |
for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
75 |
encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
76 |
77 |
self._check_encoder_blocks_norm(encoder, disable_blocks, norm)
78 |
79 |
def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
    """Assert that every conv in ``decoder.model`` carries the expected ``norm_type``.

    Blocks are counted in iteration order, one per ``StreamableConv1d`` or
    ``StreamableConvTranspose1d`` found directly in ``decoder.model``. A
    block is expected to use ``'none'`` when fewer than ``n_disable_blocks``
    blocks remain after it (``decoder.n_blocks - n_blocks``), and ``norm``
    otherwise.

    Args:
        decoder: decoder whose ``model`` layers are inspected.
        n_disable_blocks: number of trailing blocks expected to have no norm.
        norm: norm name expected on the remaining blocks.
    """
    n_blocks = 0
    for layer in decoder.model:
        if isinstance(layer, StreamableConv1d):
            n_blocks += 1
            # Compute the expected value first: writing
            # `assert a == 'none' if cond else norm` parses as
            # `(a == 'none') if cond else norm` and is vacuous when cond is false.
            expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
            assert layer.conv.norm_type == expected
        elif isinstance(layer, StreamableConvTranspose1d):
            n_blocks += 1
            expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
            assert layer.convtr.norm_type == expected
        elif isinstance(layer, SEANetResnetBlock):
            # Residual blocks do not advance the counter; they are checked
            # against the current block index.
            expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
            for resnet_layer in layer.block:
                if isinstance(resnet_layer, StreamableConv1d):
                    assert resnet_layer.conv.norm_type == expected
93 |
94 |
def test_decoder_disable_norm(self):
    """Check decoder norm placement over a grid of configurations.

    For every combination of residual-layer count, number of disabled outer
    blocks and norm type, a decoder is built and handed to
    ``_check_decoder_blocks_norm`` for the per-layer assertions.
    """
    n_residuals = [0, 1, 3]
    disable_blocks = [0, 1, 2, 3, 4, 5, 6]
    norms = ['weight_norm', 'none']
    # The loop variable gets its own name so it does not rebind the
    # `disable_blocks` list that feeds the product.
    for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
        decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
                                disable_norm_outer_blocks=n_disable)
        self._check_decoder_blocks_norm(decoder, n_disable, norm)
102 |
103 |
def test_disable_norm_raises_exception(self):
104 |
# Invalid disable_norm_outer_blocks values raise exceptions
105 |
with pytest.raises(AssertionError):
106 |
107 |
108 |
with pytest.raises(AssertionError):
109 |
SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
110 |
111 |
with pytest.raises(AssertionError):
112 |
113 |
114 |
with pytest.raises(AssertionError):
115 |
SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
@@ -1,247 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from itertools import product
8 |
9 |
import pytest
10 |
import torch
11 |
12 |
from audiocraft.modules.transformer import StreamingMultiheadAttention, StreamingTransformer
13 |
14 |
15 |
def test_transformer_causal_streaming():
16 |
17 |
18 |
for context, custom in product([None, 10], [False, True]):
19 |
# Test that causality and receptive fields are properly handled.
20 |
# looking at the gradients
21 |
tr = StreamingTransformer(
22 |
16, 4, 1 if context else 2,
23 |
causal=True, past_context=context, custom=custom,
24 |
25 |
steps = 20
26 |
for k in [0, 10, 15, 19]:
27 |
x = torch.randn(4, steps, 16, requires_grad=True)
28 |
y = tr(x)
29 |
y[:, k].abs().sum().backward()
30 |
if k + 1 < steps:
31 |
assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm()
32 |
assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm()
33 |
if context is not None and k > context:
34 |
limit = k - context - 1
35 |
assert torch.allclose(x.grad[:, :limit],
36 |
torch.tensor(0.)), x.grad[:, :limit].norm()
37 |
38 |
# Now check that streaming gives the same result at batch eval.
39 |
x = torch.randn(4, steps, 16)
40 |
y = tr(x)
41 |
ys = []
42 |
with tr.streaming():
43 |
for k in range(steps):
44 |
chunk = x[:, k:k + 1, :]
45 |
46 |
y_stream =, dim=1)
47 |
delta = torch.norm(y_stream - y) / torch.norm(y)
48 |
assert delta < 1e-6, delta
49 |
50 |
51 |
def test_transformer_vs_pytorch():
    """Compare the non-causal ``StreamingTransformer`` with PyTorch's encoder.

    For both ``custom`` settings, a ``StreamingTransformer`` and a
    ``torch.nn.TransformerEncoder`` of the same size are run on the same input
    and the relative difference of their outputs must stay below 1e-6.
    """
    # Fixed seed so the tight tolerance is checked on reproducible data.
    torch.manual_seed(1234)
    # Check that in the non causal setting, we get the same result as
    # PyTorch Transformer encoder.
    for custom in [False, True]:
        tr = StreamingTransformer(
            16, 4, 2,
            causal=False, custom=custom, dropout=0., positional_scale=0.)
        layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True)
        tr_ref = torch.nn.TransformerEncoder(layer, 2)
        # Both models are randomly initialised on their own; copy the reference
        # weights so that the outputs are actually comparable.
        tr.load_state_dict(tr_ref.state_dict())

        x = torch.randn(4, 20, 16)
        y = tr(x)
        y2 = tr_ref(x)
        delta = torch.norm(y2 - y) / torch.norm(y)
        assert delta < 1e-6, delta
68 |
69 |
70 |
def test_streaming_api():
    """Check that a saved streaming state can be restored and replayed.

    After feeding the first step, the streaming state is snapshotted. The
    second step is then run twice, restoring the snapshot in between, and both
    runs must give the same output. ``flush()`` is expected to return ``None``.
    """
    tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.)

    steps = 12
    x = torch.randn(1, steps, 16)

    with torch.no_grad():
        with tr.streaming():
            _ = tr(x[:, :1])
            # Clone the tensors: the live state is advanced by later calls.
            state = {k: v.clone() for k, v in tr.get_streaming_state().items()}
            y = tr(x[:, 1:2])
            # Rewind to the snapshot; without this the second call would run
            # on a state that already contains step 1.
            tr.set_streaming_state(state)
            y2 = tr(x[:, 1:2])
            assert torch.allclose(y, y2), (y - y2).norm()
            assert tr.flush() is None
85 |
86 |
87 |
def test_memory_efficient():
    """Check that ``memory_efficient=True`` matches the ``custom=True`` transformer.

    Both models share the same weights and are evaluated on the same input
    without gradients; their outputs must agree under ``torch.allclose``.
    """
    # Fixed seed so the comparison runs on reproducible data.
    torch.manual_seed(1234)
    tr = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
    tr_mem_efficient = StreamingTransformer(
        16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
    # The two models are randomly initialised on their own; share the weights
    # so that only the attention implementation differs.
    tr_mem_efficient.load_state_dict(tr.state_dict())

    steps = 12
    x = torch.randn(3, steps, 16)

    with torch.no_grad():
        y = tr(x)
        y2 = tr_mem_efficient(x)
        assert torch.allclose(y, y2), (y - y2).norm()
102 |
103 |
104 |
def test_attention_as_float32():
105 |
106 |
cases = [
107 |
{'custom': True},
108 |
{'custom': False},
109 |
110 |
for case in cases:
111 |
tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case)
112 |
tr_float32 = StreamingTransformer(
113 |
16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case)
114 |
if not case['custom']:
115 |
# we are not using autocast here because it doesn't really
116 |
# work as expected on CPU, so we have to manually cast the weights of the MHA.
117 |
for layer in tr_float32.layers:
118 |
119 |
120 |
steps = 12
121 |
x = torch.randn(3, steps, 16, dtype=torch.bfloat16)
122 |
123 |
with torch.no_grad():
124 |
y = tr(x)
125 |
y2 = tr_float32(x)
126 |
assert not torch.allclose(y, y2), (y - y2).norm()
127 |
128 |
129 |
130 |
def test_streaming_memory_efficient():
131 |
132 |
tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
133 |
tr_mem_efficient = StreamingTransformer(
134 |
16, 4, 2, dropout=0., memory_efficient=True, causal=True)
135 |
136 |
137 |
138 |
steps = 12
139 |
x = torch.randn(3, steps, 16)
140 |
141 |
ref = tr(x)
142 |
143 |
with tr_mem_efficient.streaming():
144 |
outs = []
145 |
# frame_sizes = [2] + [1] * (steps - 2)
146 |
frame_sizes = [1] * steps
147 |
148 |
for frame_size in frame_sizes:
149 |
frame = x[:, :frame_size]
150 |
x = x[:, frame_size:]
151 |
152 |
153 |
out =, dim=1)
154 |
delta = torch.norm(out - ref) / torch.norm(out)
155 |
assert delta < 1e-6, delta
156 |
157 |
158 |
def test_cross_attention():
    """Check cross-attention behaviour of ``StreamingTransformer``.

    For both ``norm_first`` settings this verifies that:
    * a zero cross-attention source reproduces the output of the model built
      without cross-attention (same weights loaded with ``strict=False``),
    * a non-zero source changes the output,
    * omitting the source on the cross-attention model, or passing one to the
      plain model, raises ``AssertionError``.
    """
    for norm_first in [True, False]:
        m = StreamingTransformer(
            16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True)
        m_cross = StreamingTransformer(
            16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True)
        # strict=False: presumably `m` lacks the cross-attention parameters
        # that `m_cross` declares.
        m_cross.load_state_dict(m.state_dict(), strict=False)
        x = torch.randn(2, 5, 16)
        cross_x = torch.randn(2, 3, 16)
        y_ref = m(x)
        y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x)
        # With norm_first, the two should be exactly the same,
        # but with norm_first=False, we get 2 normalization in a row
        # and the epsilon value leads to a tiny change.
        atol = 0. if norm_first else 1e-6
        print((y_ref - y_cross_zero).norm() / y_ref.norm())
        assert torch.allclose(y_ref, y_cross_zero, atol=atol)

        # We now expect a difference even with a generous atol of 1e-2.
        y_cross = m_cross(x, cross_attention_src=cross_x)
        assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2)

        # One `pytest.raises` per call: the block exits at the first exception,
        # so a second call placed in the same block would never run.
        with pytest.raises(AssertionError):
            _ = m_cross(x)
        with pytest.raises(AssertionError):
            _ = m(x, cross_attention_src=cross_x)
184 |
185 |
186 |
def test_cross_attention_compat():
    """Check ``StreamingMultiheadAttention`` cross-attention against PyTorch.

    Verifies that combining ``causal=True`` with ``cross_attention=True`` is
    rejected, that the custom cross-attention loaded with the weights of
    ``torch.nn.MultiheadAttention`` reproduces its output, and that feeding
    the queries one step at a time in streaming mode gives the same result.
    """
    # Fixed seed so the tight tolerances are checked on reproducible data.
    torch.manual_seed(1234)
    num_heads = 2
    dim = num_heads * 64
    with pytest.raises(AssertionError):
        StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True)

    cross_attn = StreamingMultiheadAttention(
        dim, num_heads, dropout=0, cross_attention=True, custom=True)
    ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True)

    # We can load the regular attention state dict
    # so we have compat when loading old checkpoints.
    cross_attn.load_state_dict(ref_attn.state_dict())

    queries = torch.randn(3, 7, dim)
    keys = torch.randn(3, 9, dim)
    values = torch.randn(3, 9, dim)

    y = cross_attn(queries, keys, values)[0]
    y_ref = ref_attn(queries, keys, values)[0]
    assert torch.allclose(y, y_ref, atol=1e-7)

    # Now let's check that streaming is working properly.
    with cross_attn.streaming():
        # One query step per call, always against the full keys and values.
        ys = [
            cross_attn(queries[:, step: step + 1], keys, values)[0]
            for step in range(queries.shape[1])
        ]
        y_streaming = torch.cat(ys, dim=1)
        assert torch.allclose(y_streaming, y, atol=1e-7)
216 |
217 |
218 |
def test_repeat_kv():
    """Exercise ``StreamingMultiheadAttention`` with ``kv_repeat``.

    Expects an ``AssertionError`` when ``kv_repeat`` is combined with
    ``causal=True`` and ``cross_attention=True``, then checks that the
    ``custom=True`` variant maps a self-attention input to an output of the
    same shape.
    """
    num_heads = 8
    kv_repeat = 4
    dim = num_heads * 64
    with pytest.raises(AssertionError):
        mha = StreamingMultiheadAttention(
            dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True)
        # NOTE(review): indentation was lost in this copy, so it is unclear
        # whether this second construction belongs inside the `raises` block.
        # If it does, it never runs once the first call raises; confirm the
        # intent and give it its own `pytest.raises` if it is meant to fail.
        mha = StreamingMultiheadAttention(
            dim, num_heads, causal=True, kv_repeat=kv_repeat)
    mha = StreamingMultiheadAttention(
        dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True)
    x = torch.randn(4, 18, dim)
    # Self-attention: the same tensor is passed as query, key and value.
    y = mha(x, x, x)[0]
    assert x.shape == y.shape
233 |
234 |
235 |
def test_qk_layer_norm():
    """Smoke-test ``qk_layer_norm=True`` with and without cross-attention.

    The first model only has to run a forward pass; for the cross-attention
    model the output shape must equal the input shape.
    """
    self_attn_model = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False)
    n_steps = 12
    inputs = torch.randn(3, n_steps, 16)
    # Forward pass only: nothing is asserted on this output.
    self_attn_model(inputs)

    cross_attn_model = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True)
    cross_src = torch.randn(3, 21, 16)
    out = cross_attn_model(inputs, cross_attention_src=cross_src)
    assert out.shape == inputs.shape
@@ -1,18 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import torch
8 |
9 |
from audiocraft.quantization.vq import ResidualVectorQuantizer
10 |
11 |
12 |
class TestResidualVectorQuantizer:
    """Shape check for ``ResidualVectorQuantizer``."""

    def test_rvq(self):
        """The ``x`` field of the quantizer result keeps the input shape."""
        inputs = torch.randn(1, 16, 2048)
        quantizer = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
        # The second positional argument is passed through as 1. unchanged.
        result = quantizer(inputs, 1.)
        assert tuple(result.x.shape) == (1, 16, 2048)
@@ -1,5 +0,0 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.