Sin2pi commited on
Commit
db65863
·
verified ·
1 Parent(s): 10b785c

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +7 -9
model.py CHANGED
@@ -274,13 +274,6 @@ class rotary(nn.Module):
274
  inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
275
  self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
276
 
277
- def update_base(self, f0):
278
- f0 = f0.squeeze(0).to(device, dtype)
279
- theta = f0.mean() + 1e-8
280
- inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
281
- self.inv_freq.data.copy_(inv_freq)
282
- self.theta.data.copy_(theta)
283
-
284
  def return_f0(self, f0=None):
285
  if f0 is not None:
286
  self.f0 = f0
@@ -288,6 +281,13 @@ class rotary(nn.Module):
288
  elif hasattr(self, 'f0') and self.f0 is not None:
289
  return self.f0.squeeze(0).to(device, dtype)
290
  return None
 
 
 
 
 
 
 
291
 
292
  def get_pitch_bias(self, f0):
293
  if f0 is None:
@@ -1053,12 +1053,10 @@ class Echo(nn.Module):
1053
  def update_base(self, f0):
1054
  for name, module in self.encoder.named_modules():
1055
  if isinstance(module, (rotary)):
1056
- module.update_base(f0)
1057
  module.return_f0(f0)
1058
 
1059
  for name, module in self.decoder.named_modules():
1060
  if isinstance(module, (rotary)):
1061
- module.update_base(f0)
1062
  module.return_f0(f0)
1063
 
1064
  def set_alignment_head(self, dump: bytes):
 
274
  inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
275
  self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
276
 
 
 
 
 
 
 
 
277
  def return_f0(self, f0=None):
278
  if f0 is not None:
279
  self.f0 = f0
 
281
  elif hasattr(self, 'f0') and self.f0 is not None:
282
  return self.f0.squeeze(0).to(device, dtype)
283
  return None
284
+
285
+ def update_base(self, f0):
286
+ f0 = self.return_f0()
287
+ theta = f0.mean() + 1e-8
288
+ inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
289
+ self.inv_freq.data.copy_(inv_freq)
290
+ self.theta.data.copy_(theta)
291
 
292
  def get_pitch_bias(self, f0):
293
  if f0 is None:
 
1053
  def update_base(self, f0):
1054
  for name, module in self.encoder.named_modules():
1055
  if isinstance(module, (rotary)):
 
1056
  module.return_f0(f0)
1057
 
1058
  for name, module in self.decoder.named_modules():
1059
  if isinstance(module, (rotary)):
 
1060
  module.return_f0(f0)
1061
 
1062
  def set_alignment_head(self, dump: bytes):