Update model.py
Browse files
model.py
CHANGED
@@ -274,13 +274,6 @@ class rotary(nn.Module):
|
|
274 |
inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
275 |
self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
|
276 |
|
277 |
-
def update_base(self, f0):
|
278 |
-
f0 = f0.squeeze(0).to(device, dtype)
|
279 |
-
theta = f0.mean() + 1e-8
|
280 |
-
inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
281 |
-
self.inv_freq.data.copy_(inv_freq)
|
282 |
-
self.theta.data.copy_(theta)
|
283 |
-
|
284 |
def return_f0(self, f0=None):
|
285 |
if f0 is not None:
|
286 |
self.f0 = f0
|
@@ -288,6 +281,13 @@ class rotary(nn.Module):
|
|
288 |
elif hasattr(self, 'f0') and self.f0 is not None:
|
289 |
return self.f0.squeeze(0).to(device, dtype)
|
290 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
|
292 |
def get_pitch_bias(self, f0):
|
293 |
if f0 is None:
|
@@ -1053,12 +1053,10 @@ class Echo(nn.Module):
|
|
1053 |
def update_base(self, f0):
|
1054 |
for name, module in self.encoder.named_modules():
|
1055 |
if isinstance(module, (rotary)):
|
1056 |
-
module.update_base(f0)
|
1057 |
module.return_f0(f0)
|
1058 |
|
1059 |
for name, module in self.decoder.named_modules():
|
1060 |
if isinstance(module, (rotary)):
|
1061 |
-
module.update_base(f0)
|
1062 |
module.return_f0(f0)
|
1063 |
|
1064 |
def set_alignment_head(self, dump: bytes):
|
|
|
274 |
inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
275 |
self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
|
276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
def return_f0(self, f0=None):
|
278 |
if f0 is not None:
|
279 |
self.f0 = f0
|
|
|
281 |
elif hasattr(self, 'f0') and self.f0 is not None:
|
282 |
return self.f0.squeeze(0).to(device, dtype)
|
283 |
return None
|
284 |
+
|
285 |
+
def update_base(self, f0):
|
286 |
+
f0 = self.return_f0()
|
287 |
+
theta = f0.mean() + 1e-8
|
288 |
+
inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
289 |
+
self.inv_freq.data.copy_(inv_freq)
|
290 |
+
self.theta.data.copy_(theta)
|
291 |
|
292 |
def get_pitch_bias(self, f0):
|
293 |
if f0 is None:
|
|
|
1053 |
def update_base(self, f0):
|
1054 |
for name, module in self.encoder.named_modules():
|
1055 |
if isinstance(module, (rotary)):
|
|
|
1056 |
module.return_f0(f0)
|
1057 |
|
1058 |
for name, module in self.decoder.named_modules():
|
1059 |
if isinstance(module, (rotary)):
|
|
|
1060 |
module.return_f0(f0)
|
1061 |
|
1062 |
def set_alignment_head(self, dump: bytes):
|