Update model.py
Browse files
model.py
CHANGED
@@ -271,6 +271,7 @@ class rotary(nn.Module):
|
|
271 |
def return_f0(self, f0=None):
|
272 |
if f0 is not None:
|
273 |
self.f0 = f0
|
|
|
274 |
return f0.squeeze(0).to(device, dtype)
|
275 |
elif hasattr(self, 'f0') and self.f0 is not None:
|
276 |
return self.f0.squeeze(0).to(device, dtype)
|
@@ -303,18 +304,17 @@ class rotary(nn.Module):
|
|
303 |
return f0.to(device=device, dtype=dtype)
|
304 |
|
305 |
def synth_f0(self, f0, ctx):
|
306 |
-
f0 = self.f0proj(f0)
|
307 |
-
|
308 |
if f0.dim() == 1:
|
309 |
length = f0.shape[0]
|
310 |
if length == ctx:
|
311 |
return f0
|
312 |
frames = length / ctx
|
313 |
idx = torch.arange(ctx, device=f0.device)
|
314 |
-
# return torch.arange(1, ctx+1, device=f0.device, dtype=torch.float)
|
315 |
return f0[idx]
|
316 |
|
317 |
def align_f0(self, ctx, f0):
|
|
|
318 |
if f0.dim() == 3:
|
319 |
batch, length, dims = f0.shape
|
320 |
if length == ctx:
|
@@ -341,6 +341,7 @@ class rotary(nn.Module):
|
|
341 |
return f0[idx, :]
|
342 |
|
343 |
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
|
|
344 |
if isinstance(x, int):
|
345 |
ctx = x
|
346 |
elif isinstance(x, torch.Tensor) and x.ndim == 2:
|
@@ -350,8 +351,8 @@ class rotary(nn.Module):
|
|
350 |
else:
|
351 |
batch, head, ctx, head_dim = x.shape
|
352 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
353 |
-
|
354 |
-
f0 = enc.get("f0") if enc is not None else None
|
355 |
if f0 is not None and f0.dim() == 2:
|
356 |
if f0.shape[0] == 1:
|
357 |
f0 = f0.squeeze(0)
|
@@ -362,7 +363,7 @@ class rotary(nn.Module):
|
|
362 |
f0_mean = f0.mean()
|
363 |
theta = f0_mean + self.theta
|
364 |
else:
|
365 |
-
theta =
|
366 |
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
|
367 |
self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
368 |
|
@@ -378,13 +379,21 @@ class rotary(nn.Module):
|
|
378 |
idx = torch.arange(ctx, device=f0.device)
|
379 |
idx = (idx * F).long().clamp(0, L - 1)
|
380 |
radius = radius[idx]
|
|
|
381 |
radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
|
|
|
382 |
else:
|
383 |
radius = torch.ones_like(freqs)
|
384 |
freqs = torch.polar(radius, freqs)
|
385 |
|
386 |
-
if "
|
387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
|
389 |
if "rot3" in self.debug and self.counter % 100 == 0:
|
390 |
print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
|
@@ -410,6 +419,19 @@ class rotary(nn.Module):
|
|
410 |
x1 = x1.view(orig_shape)
|
411 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
413 |
class MultiheadA(nn.Module):
|
414 |
_seen = set()
|
415 |
rbf = False
|
|
|
271 |
def return_f0(self, f0=None):
|
272 |
if f0 is not None:
|
273 |
self.f0 = f0
|
274 |
+
self.update_base(f0)
|
275 |
return f0.squeeze(0).to(device, dtype)
|
276 |
elif hasattr(self, 'f0') and self.f0 is not None:
|
277 |
return self.f0.squeeze(0).to(device, dtype)
|
|
|
304 |
return f0.to(device=device, dtype=dtype)
|
305 |
|
306 |
def synth_f0(self, f0, ctx):
|
307 |
+
# f0 = self.f0proj(f0)
|
|
|
308 |
if f0.dim() == 1:
|
309 |
length = f0.shape[0]
|
310 |
if length == ctx:
|
311 |
return f0
|
312 |
frames = length / ctx
|
313 |
idx = torch.arange(ctx, device=f0.device)
|
|
|
314 |
return f0[idx]
|
315 |
|
316 |
def align_f0(self, ctx, f0):
|
317 |
+
f0 = self.f0proj(f0)
|
318 |
if f0.dim() == 3:
|
319 |
batch, length, dims = f0.shape
|
320 |
if length == ctx:
|
|
|
341 |
return f0[idx, :]
|
342 |
|
343 |
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
344 |
+
f0 = enc.get("f0") if enc is not None else None
|
345 |
if isinstance(x, int):
|
346 |
ctx = x
|
347 |
elif isinstance(x, torch.Tensor) and x.ndim == 2:
|
|
|
351 |
else:
|
352 |
batch, head, ctx, head_dim = x.shape
|
353 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
354 |
+
|
355 |
+
f0 = enc.get("f0") if enc is not None else None
|
356 |
if f0 is not None and f0.dim() == 2:
|
357 |
if f0.shape[0] == 1:
|
358 |
f0 = f0.squeeze(0)
|
|
|
363 |
f0_mean = f0.mean()
|
364 |
theta = f0_mean + self.theta
|
365 |
else:
|
366 |
+
theta = 10000.0
|
367 |
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
|
368 |
self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
369 |
|
|
|
379 |
idx = torch.arange(ctx, device=f0.device)
|
380 |
idx = (idx * F).long().clamp(0, L - 1)
|
381 |
radius = radius[idx]
|
382 |
+
rad = radius
|
383 |
radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
|
384 |
+
radius = torch.sigmoid(radius)
|
385 |
else:
|
386 |
radius = torch.ones_like(freqs)
|
387 |
freqs = torch.polar(radius, freqs)
|
388 |
|
389 |
+
if "radius" in self.debug and self.counter % 100 == 0:
|
390 |
+
theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
|
391 |
+
print(f" [{layer}] [Radius] {radius.shape} {radius.mean():.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None}")
|
392 |
+
|
393 |
+
if "rot3" in self.debug and self.counter % 100 == 0:
|
394 |
+
theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
|
395 |
+
print(f" [{layer}] [f0] {f0.shape if f0 is not None else None} [Theta] {theta_value:.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
|
396 |
+
|
397 |
|
398 |
if "rot3" in self.debug and self.counter % 100 == 0:
|
399 |
print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
|
|
|
419 |
x1 = x1.view(orig_shape)
|
420 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
421 |
|
422 |
+
@staticmethod
|
423 |
+
def apply_rotary(x, freqs):
|
424 |
+
x1 = x[..., :freqs.shape[-1]*2]
|
425 |
+
x2 = x[..., freqs.shape[-1]*2:]
|
426 |
+
orig_shape = x1.shape
|
427 |
+
if x1.ndim == 2:
|
428 |
+
x1 = x1.unsqueeze(0)
|
429 |
+
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
|
430 |
+
x1 = torch.view_as_complex(x1) * freqs
|
431 |
+
x1 = torch.view_as_real(x1).flatten(-2)
|
432 |
+
x1 = x1.view(orig_shape)
|
433 |
+
return torch.cat([x1.type_as(x), x2], dim=-1)
|
434 |
+
|
435 |
class MultiheadA(nn.Module):
|
436 |
_seen = set()
|
437 |
rbf = False
|