Update model.py
Browse files
model.py
CHANGED
@@ -12,6 +12,7 @@ import torch.nn.functional as F
|
|
12 |
import torch.nn.init as init
|
13 |
from torch import nn, Tensor
|
14 |
import numpy as np
|
|
|
15 |
import matplotlib.pyplot as plt
|
16 |
from typing import Optional, Dict, Union, List, Tuple, Any
|
17 |
from functools import partial
|
@@ -249,24 +250,20 @@ def sinusoids(length, channels, max_timescale=10000):
|
|
249 |
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
250 |
|
251 |
class rotary(nn.Module):
|
252 |
-
def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False
|
253 |
super(rotary, self).__init__()
|
254 |
|
|
|
255 |
self.dims = dims
|
256 |
self.head = head
|
257 |
self.head_dim = dims // head
|
258 |
-
self.dim = self.head_dim
|
259 |
-
self.max_ctx = max_ctx
|
260 |
-
self.theta = theta
|
261 |
self.radii = radii
|
262 |
-
self.
|
263 |
-
self.use_pbias = use_pbias
|
264 |
-
self.spec_shape = spec_shape
|
265 |
self.debug = debug
|
266 |
self.counter = 0
|
267 |
self.last_theta = None
|
268 |
|
269 |
-
|
270 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
271 |
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
272 |
self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
|
@@ -287,7 +284,8 @@ class rotary(nn.Module):
|
|
287 |
self.theta.data.copy_(theta)
|
288 |
|
289 |
def get_pitch_bias(self, f0):
|
290 |
-
f0
|
|
|
291 |
f0_flat = f0.squeeze().float()
|
292 |
f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
|
293 |
f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
|
@@ -317,8 +315,6 @@ class rotary(nn.Module):
|
|
317 |
return f0[idx]
|
318 |
|
319 |
def align_f0(self, ctx, f0):
|
320 |
-
# f0 = self.return_f0()
|
321 |
-
# f0 = self.f0proj(f0)
|
322 |
if f0.dim() == 3:
|
323 |
batch, length, dims = f0.shape
|
324 |
if length == ctx:
|
@@ -345,24 +341,23 @@ class rotary(nn.Module):
|
|
345 |
return f0[idx, :]
|
346 |
|
347 |
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
348 |
-
f0 = default(enc.get("pitch"), enc.get("f0")) if enc is not None else None
|
349 |
-
if f0 is not None and f0.dim() == 2:
|
350 |
-
if f0.shape[0] == 1:
|
351 |
-
f0 = f0.squeeze(0)
|
352 |
-
else:
|
353 |
-
f0 = f0.view(-1)
|
354 |
-
|
355 |
-
if "rot1" in self.debug and self.counter % 100 == 0:
|
356 |
-
print(f"Rotary forward: {x if x is not None else None}, f0: {f0.shape if f0 is not None else None}")
|
357 |
-
|
358 |
if isinstance(x, int):
|
359 |
ctx = x
|
|
|
|
|
360 |
elif isinstance(x, torch.Tensor) and x.ndim == 3:
|
361 |
batch, ctx, dims = x.shape
|
362 |
else:
|
363 |
batch, head, ctx, head_dim = x.shape
|
364 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
if f0 is not None:
|
367 |
f0_mean = f0.mean()
|
368 |
theta = f0_mean + self.theta
|
@@ -388,6 +383,9 @@ class rotary(nn.Module):
|
|
388 |
radius = torch.ones_like(freqs)
|
389 |
freqs = torch.polar(radius, freqs)
|
390 |
|
|
|
|
|
|
|
391 |
if "rot3" in self.debug and self.counter % 100 == 0:
|
392 |
print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
|
393 |
|
@@ -859,13 +857,11 @@ class AudioEncoder(nn.Module):
|
|
859 |
),
|
860 |
"envelope": nn.ModuleList(
|
861 |
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
862 |
-
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
|
863 |
-
for _ in range(layer)] if "envelope" in features else None
|
864 |
),
|
865 |
"phase": nn.ModuleList(
|
866 |
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
867 |
-
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
|
868 |
-
for _ in range(layer)] if "phase" in features else None
|
869 |
)
|
870 |
})
|
871 |
|
@@ -899,7 +895,7 @@ class AudioEncoder(nn.Module):
|
|
899 |
|
900 |
class TextDecoder(nn.Module):
|
901 |
def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
|
902 |
-
debug: List[str], features: List[str]
|
903 |
super(TextDecoder, self).__init__()
|
904 |
|
905 |
self.ctx = ctx
|
@@ -909,7 +905,6 @@ class TextDecoder(nn.Module):
|
|
909 |
self.debug = debug
|
910 |
self.counter = 0
|
911 |
self.dropout = 0.01
|
912 |
-
self.sequential = sequential
|
913 |
self.features = features
|
914 |
|
915 |
self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
|
@@ -931,7 +926,7 @@ class TextDecoder(nn.Module):
|
|
931 |
mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
|
932 |
self.register_buffer("mask", mask, persistent=False)
|
933 |
|
934 |
-
def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
|
935 |
enc = dict_to(enc, device, dtype)
|
936 |
x = x.to(device)
|
937 |
bln = self.blend
|
@@ -943,17 +938,25 @@ class TextDecoder(nn.Module):
|
|
943 |
x = self.token(x) + self.positional[:x.shape[1]]
|
944 |
x = F.dropout(x, p=self.dropout, training=self.training)
|
945 |
|
|
|
|
|
|
|
|
|
946 |
for block in self.block:
|
947 |
x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
|
948 |
|
949 |
for f in order:
|
950 |
if f in enc:
|
|
|
951 |
xa = enc[f]
|
952 |
for block in self.blocks[f]:
|
953 |
out = block(x=x, xa=xa, mask=None, enc=enc, layer=layer)
|
954 |
|
955 |
-
|
956 |
-
|
|
|
|
|
|
|
957 |
|
958 |
if "decoder" in self.debug and self.counter % 100 == 0:
|
959 |
print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
|
@@ -1019,19 +1022,16 @@ class Echo(nn.Module):
|
|
1019 |
return self.decoder(input_ids, encoder_output)
|
1020 |
|
1021 |
def forward(self,
|
1022 |
-
decoder_input_ids=None,
|
1023 |
labels=None,
|
1024 |
waveform: Optional[torch.Tensor]=None,
|
1025 |
input_ids=None,
|
1026 |
spectrogram: torch.Tensor=None,
|
1027 |
pitch: Optional[torch.Tensor]=None,
|
1028 |
f0: Optional[torch.Tensor]=None,
|
1029 |
-
f0d: Optional[torch.Tensor]=None,
|
1030 |
envelope: Optional[torch.Tensor]=None,
|
1031 |
phase: Optional[torch.Tensor]=None,
|
1032 |
) -> Dict[str, torch.Tensor]:
|
1033 |
|
1034 |
-
decoder_input_ids = input_ids
|
1035 |
encoder_inputs = {}
|
1036 |
if spectrogram is not None:
|
1037 |
encoder_inputs["spectrogram"] = spectrogram
|
@@ -1046,11 +1046,6 @@ class Echo(nn.Module):
|
|
1046 |
if f0 is not None:
|
1047 |
encoder_inputs["f0"] = f0
|
1048 |
|
1049 |
-
|
1050 |
-
# if f0 is not None:
|
1051 |
-
# f0 = f0.squeeze(0)
|
1052 |
-
# self.update_base(f0)
|
1053 |
-
|
1054 |
encoder_outputs = self.encoder(encoder_inputs)
|
1055 |
logits = self.decoder(input_ids, encoder_outputs)
|
1056 |
|
@@ -1063,10 +1058,6 @@ class Echo(nn.Module):
|
|
1063 |
return {
|
1064 |
"logits": logits,
|
1065 |
"loss": loss,
|
1066 |
-
# "labels": labels,
|
1067 |
-
# "input_ids": input_ids,
|
1068 |
-
# "decoder_input_ids": decoder_input_ids,
|
1069 |
-
# "encoder_output": encoder_outputs,
|
1070 |
}
|
1071 |
|
1072 |
@property
|
@@ -1617,9 +1608,9 @@ def get_training_args(
|
|
1617 |
per_device_train_batch_size=1,
|
1618 |
per_device_eval_batch_size=1,
|
1619 |
gradient_accumulation_steps=1,
|
1620 |
-
eval_accumulation_steps=
|
1621 |
eval_strategy="steps",
|
1622 |
-
save_strategy="
|
1623 |
max_steps=max_steps,
|
1624 |
save_steps=save_steps,
|
1625 |
eval_steps=eval_steps,
|
@@ -1703,7 +1694,7 @@ def main():
|
|
1703 |
training_args = sanity(sanity_check)
|
1704 |
dataset_config = {
|
1705 |
"spectrogram": True,
|
1706 |
-
"waveforms":
|
1707 |
"pitch": False,
|
1708 |
"downsamples": False,
|
1709 |
"frequency": False,
|
@@ -1752,6 +1743,10 @@ def main():
|
|
1752 |
if __name__ == "__main__":
|
1753 |
main()
|
1754 |
|
1755 |
-
|
1756 |
-
|
|
|
|
|
|
|
|
|
1757 |
|
|
|
12 |
import torch.nn.init as init
|
13 |
from torch import nn, Tensor
|
14 |
import numpy as np
|
15 |
+
from einops import rearrange
|
16 |
import matplotlib.pyplot as plt
|
17 |
from typing import Optional, Dict, Union, List, Tuple, Any
|
18 |
from functools import partial
|
|
|
250 |
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
251 |
|
252 |
class rotary(nn.Module):
|
253 |
+
def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False):
|
254 |
super(rotary, self).__init__()
|
255 |
|
256 |
+
self.use_pbias = use_pbias
|
257 |
self.dims = dims
|
258 |
self.head = head
|
259 |
self.head_dim = dims // head
|
|
|
|
|
|
|
260 |
self.radii = radii
|
261 |
+
self.dim = self.head_dim
|
|
|
|
|
262 |
self.debug = debug
|
263 |
self.counter = 0
|
264 |
self.last_theta = None
|
265 |
|
266 |
+
self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
|
267 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
268 |
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
269 |
self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
|
|
|
284 |
self.theta.data.copy_(theta)
|
285 |
|
286 |
def get_pitch_bias(self, f0):
|
287 |
+
if f0 is None:
|
288 |
+
return None
|
289 |
f0_flat = f0.squeeze().float()
|
290 |
f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
|
291 |
f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
|
|
|
315 |
return f0[idx]
|
316 |
|
317 |
def align_f0(self, ctx, f0):
|
|
|
|
|
318 |
if f0.dim() == 3:
|
319 |
batch, length, dims = f0.shape
|
320 |
if length == ctx:
|
|
|
341 |
return f0[idx, :]
|
342 |
|
343 |
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
if isinstance(x, int):
|
345 |
ctx = x
|
346 |
+
elif isinstance(x, torch.Tensor) and x.ndim == 2:
|
347 |
+
batch, ctx = x.shape
|
348 |
elif isinstance(x, torch.Tensor) and x.ndim == 3:
|
349 |
batch, ctx, dims = x.shape
|
350 |
else:
|
351 |
batch, head, ctx, head_dim = x.shape
|
352 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
353 |
|
354 |
+
f0 = default(enc.get("pitch"), enc.get("f0")) if enc is not None else None
|
355 |
+
if f0 is not None and f0.dim() == 2:
|
356 |
+
if f0.shape[0] == 1:
|
357 |
+
f0 = f0.squeeze(0)
|
358 |
+
else:
|
359 |
+
f0 = f0.view(-1)
|
360 |
+
|
361 |
if f0 is not None:
|
362 |
f0_mean = f0.mean()
|
363 |
theta = f0_mean + self.theta
|
|
|
383 |
radius = torch.ones_like(freqs)
|
384 |
freqs = torch.polar(radius, freqs)
|
385 |
|
386 |
+
if "rot1" in self.debug and self.counter % 100 == 0:
|
387 |
+
print(f"Rotary forward: {x if x is not None else None}, f0: {f0.shape if f0 is not None else None}")
|
388 |
+
|
389 |
if "rot3" in self.debug and self.counter % 100 == 0:
|
390 |
print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
|
391 |
|
|
|
857 |
),
|
858 |
"envelope": nn.ModuleList(
|
859 |
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
860 |
+
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "envelope" in features else None
|
|
|
861 |
),
|
862 |
"phase": nn.ModuleList(
|
863 |
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
864 |
+
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None
|
|
|
865 |
)
|
866 |
})
|
867 |
|
|
|
895 |
|
896 |
class TextDecoder(nn.Module):
|
897 |
def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
|
898 |
+
debug: List[str], features: List[str]):
|
899 |
super(TextDecoder, self).__init__()
|
900 |
|
901 |
self.ctx = ctx
|
|
|
905 |
self.debug = debug
|
906 |
self.counter = 0
|
907 |
self.dropout = 0.01
|
|
|
908 |
self.features = features
|
909 |
|
910 |
self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
|
|
|
926 |
mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
|
927 |
self.register_buffer("mask", mask, persistent=False)
|
928 |
|
929 |
+
def forward(self, x, enc, order=None, layer='decoder', sequential=False) -> Tensor:
|
930 |
enc = dict_to(enc, device, dtype)
|
931 |
x = x.to(device)
|
932 |
bln = self.blend
|
|
|
938 |
x = self.token(x) + self.positional[:x.shape[1]]
|
939 |
x = F.dropout(x, p=self.dropout, training=self.training)
|
940 |
|
941 |
+
# ctx = x.shape[1]
|
942 |
+
# freqs = self.rotary(ctx)
|
943 |
+
# x = self.rotary.apply_rotary(x, freqs)
|
944 |
+
|
945 |
for block in self.block:
|
946 |
x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
|
947 |
|
948 |
for f in order:
|
949 |
if f in enc:
|
950 |
+
seq = x
|
951 |
xa = enc[f]
|
952 |
for block in self.blocks[f]:
|
953 |
out = block(x=x, xa=xa, mask=None, enc=enc, layer=layer)
|
954 |
|
955 |
+
if sequential:
|
956 |
+
x = seq
|
957 |
+
else:
|
958 |
+
a = torch.sigmoid(bln[f])
|
959 |
+
x = a * out + (1 - a) * x
|
960 |
|
961 |
if "decoder" in self.debug and self.counter % 100 == 0:
|
962 |
print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
|
|
|
1022 |
return self.decoder(input_ids, encoder_output)
|
1023 |
|
1024 |
def forward(self,
|
|
|
1025 |
labels=None,
|
1026 |
waveform: Optional[torch.Tensor]=None,
|
1027 |
input_ids=None,
|
1028 |
spectrogram: torch.Tensor=None,
|
1029 |
pitch: Optional[torch.Tensor]=None,
|
1030 |
f0: Optional[torch.Tensor]=None,
|
|
|
1031 |
envelope: Optional[torch.Tensor]=None,
|
1032 |
phase: Optional[torch.Tensor]=None,
|
1033 |
) -> Dict[str, torch.Tensor]:
|
1034 |
|
|
|
1035 |
encoder_inputs = {}
|
1036 |
if spectrogram is not None:
|
1037 |
encoder_inputs["spectrogram"] = spectrogram
|
|
|
1046 |
if f0 is not None:
|
1047 |
encoder_inputs["f0"] = f0
|
1048 |
|
|
|
|
|
|
|
|
|
|
|
1049 |
encoder_outputs = self.encoder(encoder_inputs)
|
1050 |
logits = self.decoder(input_ids, encoder_outputs)
|
1051 |
|
|
|
1058 |
return {
|
1059 |
"logits": logits,
|
1060 |
"loss": loss,
|
|
|
|
|
|
|
|
|
1061 |
}
|
1062 |
|
1063 |
@property
|
|
|
1608 |
per_device_train_batch_size=1,
|
1609 |
per_device_eval_batch_size=1,
|
1610 |
gradient_accumulation_steps=1,
|
1611 |
+
eval_accumulation_steps=None,
|
1612 |
eval_strategy="steps",
|
1613 |
+
save_strategy="no",
|
1614 |
max_steps=max_steps,
|
1615 |
save_steps=save_steps,
|
1616 |
eval_steps=eval_steps,
|
|
|
1694 |
training_args = sanity(sanity_check)
|
1695 |
dataset_config = {
|
1696 |
"spectrogram": True,
|
1697 |
+
"waveforms": True,
|
1698 |
"pitch": False,
|
1699 |
"downsamples": False,
|
1700 |
"frequency": False,
|
|
|
1743 |
if __name__ == "__main__":
|
1744 |
main()
|
1745 |
|
1746 |
+
# from tensorboard import program
|
1747 |
+
# log_dir = "./output/logs"
|
1748 |
+
# tb = program.TensorBoard()
|
1749 |
+
# tb.configure(argv=[None, '--logdir', log_dir])
|
1750 |
+
# url = tb.launch()
|
1751 |
+
# print(f"TensorBoard started at {url}")
|
1752 |
|