AiMimicry commited on
Commit
1462ed7
·
1 Parent(s): 30a73fc

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +47 -6
utils.py CHANGED
@@ -6,7 +6,9 @@ import argparse
6
  import logging
7
  import json
8
  import subprocess
 
9
  import random
 
10
 
11
  import librosa
12
  import numpy as np
@@ -15,6 +17,7 @@ import torch
15
  from torch.nn import functional as F
16
  from modules.commons import sequence_mask
17
  from hubert import hubert_model
 
18
  MATPLOTLIB_FLAG = False
19
 
20
  logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@@ -46,6 +49,21 @@ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
46
  # factor = torch.ones(f0.shape[0], 1, 1).to(f0.device)
47
  # f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
48
  # return f0_norm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def normalize_f0(f0, x_mask, uv, random_scale=True):
50
  # calculate means based on x_mask
51
  uv_sum = torch.sum(uv, dim=1, keepdim=True)
@@ -62,6 +80,19 @@ def normalize_f0(f0, x_mask, uv, random_scale=True):
62
  exit(0)
63
  return f0_norm * x_mask
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def plot_data_to_numpy(x, y):
67
  global MATPLOTLIB_FLAG
@@ -88,9 +119,6 @@ def plot_data_to_numpy(x, y):
88
 
89
 
90
  def interpolate_f0(f0):
91
- '''
92
- 对F0进行插值处理
93
- '''
94
 
95
  data = np.reshape(f0, (f0.size, 1))
96
 
@@ -120,7 +148,7 @@ def interpolate_f0(f0):
120
  for k in range(i, frame_number):
121
  ip_data[k] = last_value
122
  else:
123
- ip_data[i] = data[i]
124
  last_value = data[i]
125
 
126
  return ip_data[:,0], vuv_vector[:,0]
@@ -174,7 +202,7 @@ def f0_to_coarse(f0):
174
 
175
  f0_mel[f0_mel <= 1] = 1
176
  f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
177
- f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int)
178
  assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
179
  return f0_coarse
180
 
@@ -244,6 +272,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
244
  model.module.load_state_dict(new_state_dict)
245
  else:
246
  model.load_state_dict(new_state_dict)
 
247
  logger.info("Loaded checkpoint '{}' (iteration {})".format(
248
  checkpoint_path, iteration))
249
  return model, optimizer, learning_rate, iteration
@@ -468,6 +497,19 @@ def repeat_expand_2d(content, target_len):
468
  return target
469
 
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  class HParams():
472
  def __init__(self, **kwargs):
473
  for k, v in kwargs.items():
@@ -498,4 +540,3 @@ class HParams():
498
 
499
  def __repr__(self):
500
  return self.__dict__.__repr__()
501
-
 
6
  import logging
7
  import json
8
  import subprocess
9
+ import warnings
10
  import random
11
+ import functools
12
 
13
  import librosa
14
  import numpy as np
 
17
  from torch.nn import functional as F
18
  from modules.commons import sequence_mask
19
  from hubert import hubert_model
20
+
21
  MATPLOTLIB_FLAG = False
22
 
23
  logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 
49
  # factor = torch.ones(f0.shape[0], 1, 1).to(f0.device)
50
  # f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
51
  # return f0_norm
52
+
53
+ def deprecated(func):
54
+ """This is a decorator which can be used to mark functions
55
+ as deprecated. It will result in a warning being emitted
56
+ when the function is used."""
57
+ @functools.wraps(func)
58
+ def new_func(*args, **kwargs):
59
+ warnings.simplefilter('always', DeprecationWarning) # turn off filter
60
+ warnings.warn("Call to deprecated function {}.".format(func.__name__),
61
+ category=DeprecationWarning,
62
+ stacklevel=2)
63
+ warnings.simplefilter('default', DeprecationWarning) # reset filter
64
+ return func(*args, **kwargs)
65
+ return new_func
66
+
67
  def normalize_f0(f0, x_mask, uv, random_scale=True):
68
  # calculate means based on x_mask
69
  uv_sum = torch.sum(uv, dim=1, keepdim=True)
 
80
  exit(0)
81
  return f0_norm * x_mask
82
 
83
+ def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512,device=None,cr_threshold=0.05):
84
+ from modules.crepe import CrepePitchExtractor
85
+ x = wav_numpy
86
+ if p_len is None:
87
+ p_len = x.shape[0]//hop_length
88
+ else:
89
+ assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error"
90
+
91
+ f0_min = 50
92
+ f0_max = 1100
93
+ F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=cr_threshold)
94
+ f0,uv = F0Creper(x[None,:].float(),sampling_rate,pad_to=p_len)
95
+ return f0,uv
96
 
97
  def plot_data_to_numpy(x, y):
98
  global MATPLOTLIB_FLAG
 
119
 
120
 
121
  def interpolate_f0(f0):
 
 
 
122
 
123
  data = np.reshape(f0, (f0.size, 1))
124
 
 
148
  for k in range(i, frame_number):
149
  ip_data[k] = last_value
150
  else:
151
+ ip_data[i] = data[i] # this may not be necessary
152
  last_value = data[i]
153
 
154
  return ip_data[:,0], vuv_vector[:,0]
 
202
 
203
  f0_mel[f0_mel <= 1] = 1
204
  f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
205
+ f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(np.int)
206
  assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
207
  return f0_coarse
208
 
 
272
  model.module.load_state_dict(new_state_dict)
273
  else:
274
  model.load_state_dict(new_state_dict)
275
+ print("load ")
276
  logger.info("Loaded checkpoint '{}' (iteration {})".format(
277
  checkpoint_path, iteration))
278
  return model, optimizer, learning_rate, iteration
 
497
  return target
498
 
499
 
500
+ def mix_model(model_paths,mix_rate,mode):
501
+ mix_rate = torch.FloatTensor(mix_rate)/100
502
+ model_tem = torch.load(model_paths[0])
503
+ models = [torch.load(path)["model"] for path in model_paths]
504
+ if mode == 0:
505
+ mix_rate = F.softmax(mix_rate,dim=0)
506
+ for k in model_tem["model"].keys():
507
+ model_tem["model"][k] = torch.zeros_like(model_tem["model"][k])
508
+ for i,model in enumerate(models):
509
+ model_tem["model"][k] += model[k]*mix_rate[i]
510
+ torch.save(model_tem,os.path.join(os.path.curdir,"output.pth"))
511
+ return os.path.join(os.path.curdir,"output.pth")
512
+
513
  class HParams():
514
  def __init__(self, **kwargs):
515
  for k, v in kwargs.items():
 
540
 
541
  def __repr__(self):
542
  return self.__dict__.__repr__()