Sin2pi commited on
Commit
1e6e4a2
·
verified ·
1 Parent(s): 401ff59

Update model_hf.py

Browse files
Files changed (1) hide show
  1. model_hf.py +5 -7
model_hf.py CHANGED
@@ -377,14 +377,15 @@ class rotary(nn.Module):
377
  else:
378
  f0 = f0.view(-1)
379
 
380
- if f0 is not None and layer == "encoder":
381
  f0_mean = f0.mean()
382
  theta = f0_mean + self.theta
383
  else:
384
  theta = self.theta
385
  freqs = self.theta_freqs(theta)
386
  freqs = t[:, None] * freqs[None, :]
387
- if self.radii and f0 is not None and layer == "encoder":
 
388
  radius = f0.to(device, dtype)
389
  L = radius.shape[0]
390
  if L != ctx:
@@ -392,12 +393,9 @@ class rotary(nn.Module):
392
  idx = torch.arange(ctx, device=f0.device)
393
  idx = (idx * F).long().clamp(0, L - 1)
394
  radius = radius[idx]
395
-
396
- radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
397
- # radius = torch.sigmoid(radius)
398
  else:
399
- radius = torch.ones_like(freqs)
400
- freqs = torch.polar(radius, freqs)
401
 
402
  if "radius" in self.debug and self.counter % 100 == 0:
403
  theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
 
377
  else:
378
  f0 = f0.view(-1)
379
 
380
+ if f0 is not None:
381
  f0_mean = f0.mean()
382
  theta = f0_mean + self.theta
383
  else:
384
  theta = self.theta
385
  freqs = self.theta_freqs(theta)
386
  freqs = t[:, None] * freqs[None, :]
387
+
388
+ if self.radii and f0 is not None:
389
  radius = f0.to(device, dtype)
390
  L = radius.shape[0]
391
  if L != ctx:
 
393
  idx = torch.arange(ctx, device=f0.device)
394
  idx = (idx * F).long().clamp(0, L - 1)
395
  radius = radius[idx]
396
+ freqs = torch.polar(radius.unsqueeze(-1).expand_as(freqs), freqs)
 
 
397
  else:
398
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
 
399
 
400
  if "radius" in self.debug and self.counter % 100 == 0:
401
  theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta