Sin2pi committed on
Commit
e9bf4ee
·
verified ·
1 Parent(s): bebb811

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +30 -8
model.py CHANGED
@@ -271,6 +271,7 @@ class rotary(nn.Module):
271
  def return_f0(self, f0=None):
272
  if f0 is not None:
273
  self.f0 = f0
 
274
  return f0.squeeze(0).to(device, dtype)
275
  elif hasattr(self, 'f0') and self.f0 is not None:
276
  return self.f0.squeeze(0).to(device, dtype)
@@ -303,18 +304,17 @@ class rotary(nn.Module):
303
  return f0.to(device=device, dtype=dtype)
304
 
305
  def synth_f0(self, f0, ctx):
306
- f0 = self.f0proj(f0)
307
-
308
  if f0.dim() == 1:
309
  length = f0.shape[0]
310
  if length == ctx:
311
  return f0
312
  frames = length / ctx
313
  idx = torch.arange(ctx, device=f0.device)
314
- # return torch.arange(1, ctx+1, device=f0.device, dtype=torch.float)
315
  return f0[idx]
316
 
317
  def align_f0(self, ctx, f0):
 
318
  if f0.dim() == 3:
319
  batch, length, dims = f0.shape
320
  if length == ctx:
@@ -341,6 +341,7 @@ class rotary(nn.Module):
341
  return f0[idx, :]
342
 
343
  def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
 
344
  if isinstance(x, int):
345
  ctx = x
346
  elif isinstance(x, torch.Tensor) and x.ndim == 2:
@@ -350,8 +351,8 @@ class rotary(nn.Module):
350
  else:
351
  batch, head, ctx, head_dim = x.shape
352
  t = torch.arange(ctx, device=device, dtype=dtype)
353
-
354
- f0 = enc.get("f0") if enc is not None else None
355
  if f0 is not None and f0.dim() == 2:
356
  if f0.shape[0] == 1:
357
  f0 = f0.squeeze(0)
@@ -362,7 +363,7 @@ class rotary(nn.Module):
362
  f0_mean = f0.mean()
363
  theta = f0_mean + self.theta
364
  else:
365
- theta = self.theta
366
  freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
367
  self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
368
 
@@ -378,13 +379,21 @@ class rotary(nn.Module):
378
  idx = torch.arange(ctx, device=f0.device)
379
  idx = (idx * F).long().clamp(0, L - 1)
380
  radius = radius[idx]
 
381
  radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
 
382
  else:
383
  radius = torch.ones_like(freqs)
384
  freqs = torch.polar(radius, freqs)
385
 
386
- if "rot1" in self.debug and self.counter % 100 == 0:
387
- print(f"Rotary forward: {x if x is not None else None}, f0: {f0.shape if f0 is not None else None}")
 
 
 
 
 
 
388
 
389
  if "rot3" in self.debug and self.counter % 100 == 0:
390
  print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
@@ -410,6 +419,19 @@ class rotary(nn.Module):
410
  x1 = x1.view(orig_shape)
411
  return torch.cat([x1.type_as(x), x2], dim=-1)
412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  class MultiheadA(nn.Module):
414
  _seen = set()
415
  rbf = False
 
271
  def return_f0(self, f0=None):
272
  if f0 is not None:
273
  self.f0 = f0
274
+ self.update_base(f0)
275
  return f0.squeeze(0).to(device, dtype)
276
  elif hasattr(self, 'f0') and self.f0 is not None:
277
  return self.f0.squeeze(0).to(device, dtype)
 
304
  return f0.to(device=device, dtype=dtype)
305
 
306
  def synth_f0(self, f0, ctx):
307
+ # f0 = self.f0proj(f0)
 
308
  if f0.dim() == 1:
309
  length = f0.shape[0]
310
  if length == ctx:
311
  return f0
312
  frames = length / ctx
313
  idx = torch.arange(ctx, device=f0.device)
 
314
  return f0[idx]
315
 
316
  def align_f0(self, ctx, f0):
317
+ f0 = self.f0proj(f0)
318
  if f0.dim() == 3:
319
  batch, length, dims = f0.shape
320
  if length == ctx:
 
341
  return f0[idx, :]
342
 
343
  def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
344
+ f0 = enc.get("f0") if enc is not None else None
345
  if isinstance(x, int):
346
  ctx = x
347
  elif isinstance(x, torch.Tensor) and x.ndim == 2:
 
351
  else:
352
  batch, head, ctx, head_dim = x.shape
353
  t = torch.arange(ctx, device=device, dtype=dtype)
354
+
355
+ f0 = enc.get("f0") if enc is not None else None
356
  if f0 is not None and f0.dim() == 2:
357
  if f0.shape[0] == 1:
358
  f0 = f0.squeeze(0)
 
363
  f0_mean = f0.mean()
364
  theta = f0_mean + self.theta
365
  else:
366
+ theta = 10000.0
367
  freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
368
  self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
369
 
 
379
  idx = torch.arange(ctx, device=f0.device)
380
  idx = (idx * F).long().clamp(0, L - 1)
381
  radius = radius[idx]
382
+ rad = radius
383
  radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
384
+ radius = torch.sigmoid(radius)
385
  else:
386
  radius = torch.ones_like(freqs)
387
  freqs = torch.polar(radius, freqs)
388
 
389
+ if "radius" in self.debug and self.counter % 100 == 0:
390
+ theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
391
+ print(f" [{layer}] [Radius] {radius.shape} {radius.mean():.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None}")
392
+
393
+ if "rot3" in self.debug and self.counter % 100 == 0:
394
+ theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
395
+ print(f" [{layer}] [f0] {f0.shape if f0 is not None else None} [Theta] {theta_value:.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
396
+
397
 
398
  if "rot3" in self.debug and self.counter % 100 == 0:
399
  print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
 
419
  x1 = x1.view(orig_shape)
420
  return torch.cat([x1.type_as(x), x2], dim=-1)
421
 
422
+ @staticmethod
423
+ def apply_rotary(x, freqs):
424
+ x1 = x[..., :freqs.shape[-1]*2]
425
+ x2 = x[..., freqs.shape[-1]*2:]
426
+ orig_shape = x1.shape
427
+ if x1.ndim == 2:
428
+ x1 = x1.unsqueeze(0)
429
+ x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
430
+ x1 = torch.view_as_complex(x1) * freqs
431
+ x1 = torch.view_as_real(x1).flatten(-2)
432
+ x1 = x1.view(orig_shape)
433
+ return torch.cat([x1.type_as(x), x2], dim=-1)
434
+
435
  class MultiheadA(nn.Module):
436
  _seen = set()
437
  rbf = False