shibing624 committed
Commit 4fe860a · verified · 1 Parent(s): 79ed204

Delete utils.py

Files changed (1)
  1. utils.py +0 -362
utils.py DELETED
@@ -1,362 +0,0 @@
- import os
- import glob
- import sys
- import argparse
- import logging
- import json
- import subprocess
- import traceback
-
- import librosa
- import numpy as np
- from scipy.io.wavfile import read
- import torch
- import logging
-
- logging.getLogger("numba").setLevel(logging.ERROR)
- logging.getLogger("matplotlib").setLevel(logging.ERROR)
-
- MATPLOTLIB_FLAG = False
-
- logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
- logger = logging
-
-
- def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
-     assert os.path.isfile(checkpoint_path)
-     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
-     iteration = checkpoint_dict["iteration"]
-     learning_rate = checkpoint_dict["learning_rate"]
-     if (
-         optimizer is not None
-         and not skip_optimizer
-         and checkpoint_dict["optimizer"] is not None
-     ):
-         optimizer.load_state_dict(checkpoint_dict["optimizer"])
-     saved_state_dict = checkpoint_dict["model"]
-     if hasattr(model, "module"):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     new_state_dict = {}
-     for k, v in state_dict.items():
-         try:
-             # assert "quantizer" not in k
-             # print("load", k)
-             new_state_dict[k] = saved_state_dict[k]
-             assert saved_state_dict[k].shape == v.shape, (
-                 saved_state_dict[k].shape,
-                 v.shape,
-             )
-         except:
-             traceback.print_exc()
-             print(
-                 "error, %s is not in the checkpoint" % k
-             )  # a shape mismatch also ends up here, e.g. text_embedding when the cleaner is changed
-             new_state_dict[k] = v
-     if hasattr(model, "module"):
-         model.module.load_state_dict(new_state_dict)
-     else:
-         model.load_state_dict(new_state_dict)
-     print("load ")
-     logger.info(
-         "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
-     )
-     return model, optimizer, learning_rate, iteration
-
-
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-     logger.info(
-         "Saving model and optimizer state at iteration {} to {}".format(
-             iteration, checkpoint_path
-         )
-     )
-     if hasattr(model, "module"):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     torch.save(
-         {
-             "model": state_dict,
-             "iteration": iteration,
-             "optimizer": optimizer.state_dict(),
-             "learning_rate": learning_rate,
-         },
-         checkpoint_path,
-     )
-
-
- def summarize(
-     writer,
-     global_step,
-     scalars={},
-     histograms={},
-     images={},
-     audios={},
-     audio_sampling_rate=22050,
- ):
-     for k, v in scalars.items():
-         writer.add_scalar(k, v, global_step)
-     for k, v in histograms.items():
-         writer.add_histogram(k, v, global_step)
-     for k, v in images.items():
-         writer.add_image(k, v, global_step, dataformats="HWC")
-     for k, v in audios.items():
-         writer.add_audio(k, v, global_step, audio_sampling_rate)
-
-
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
-     f_list = glob.glob(os.path.join(dir_path, regex))
-     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
-     x = f_list[-1]
-     print(x)
-     return x
-
-
- def plot_spectrogram_to_numpy(spectrogram):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger("matplotlib")
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(10, 2))
-     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-     plt.colorbar(im, ax=ax)
-     plt.xlabel("Frames")
-     plt.ylabel("Channels")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def plot_alignment_to_numpy(alignment, info=None):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger("matplotlib")
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(6, 4))
-     im = ax.imshow(
-         alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
-     )
-     fig.colorbar(im, ax=ax)
-     xlabel = "Decoder timestep"
-     if info is not None:
-         xlabel += "\n\n" + info
-     plt.xlabel(xlabel)
-     plt.ylabel("Encoder timestep")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def load_wav_to_torch(full_path):
-     data, sampling_rate = librosa.load(full_path, sr=None)
-     return torch.FloatTensor(data), sampling_rate
-
-
- def load_filepaths_and_text(filename, split="|"):
-     with open(filename, encoding="utf-8") as f:
-         filepaths_and_text = [line.strip().split(split) for line in f]
-     return filepaths_and_text
-
-
- def get_hparams(init=True, stage=1):
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "-c",
-         "--config",
-         type=str,
-         default="./configs/s2.json",
-         help="JSON file for configuration",
-     )
-     parser.add_argument(
-         "-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
-     )
-     parser.add_argument(
-         "-rs",
-         "--resume_step",
-         type=int,
-         required=False,
-         default=None,
-         help="resume step",
-     )
-     # parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
-     # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
-     # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
-
-     args = parser.parse_args()
-
-     config_path = args.config
-     with open(config_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     hparams.pretrain = args.pretrain
-     hparams.resume_step = args.resume_step
-     # hparams.data.exp_dir = args.exp_dir
-     if stage == 1:
-         model_dir = hparams.s1_ckpt_dir
-     else:
-         model_dir = hparams.s2_ckpt_dir
-     config_save_path = os.path.join(model_dir, "config.json")
-
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-
-     with open(config_save_path, "w") as f:
-         f.write(data)
-     return hparams
-
-
- def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
-     """Freeing up space by deleting saved ckpts
-
-     Arguments:
-     path_to_models -- Path to the model directory
-     n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
-     sort_by_time -- True -> chronologically delete ckpts
-                     False -> lexicographically delete ckpts
-     """
-     import re
-
-     ckpts_files = [
-         f
-         for f in os.listdir(path_to_models)
-         if os.path.isfile(os.path.join(path_to_models, f))
-     ]
-     name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
-     time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
-     sort_key = time_key if sort_by_time else name_key
-     x_sorted = lambda _x: sorted(
-         [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
-         key=sort_key,
-     )
-     to_del = [
-         os.path.join(path_to_models, fn)
-         for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
-     ]
-     del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
-     del_routine = lambda x: [os.remove(x), del_info(x)]
-     rs = [del_routine(fn) for fn in to_del]
-
-
- def get_hparams_from_dir(model_dir):
-     config_save_path = os.path.join(model_dir, "config.json")
-     with open(config_save_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     hparams.model_dir = model_dir
-     return hparams
-
-
- def get_hparams_from_file(config_path):
-     with open(config_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     return hparams
-
-
- def check_git_hash(model_dir):
-     source_dir = os.path.dirname(os.path.realpath(__file__))
-     if not os.path.exists(os.path.join(source_dir, ".git")):
-         logger.warn(
-             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                 source_dir
-             )
-         )
-         return
-
-     cur_hash = subprocess.getoutput("git rev-parse HEAD")
-
-     path = os.path.join(model_dir, "githash")
-     if os.path.exists(path):
-         saved_hash = open(path).read()
-         if saved_hash != cur_hash:
-             logger.warn(
-                 "git hash values are different. {}(saved) != {}(current)".format(
-                     saved_hash[:8], cur_hash[:8]
-                 )
-             )
-     else:
-         open(path, "w").write(cur_hash)
-
-
- def get_logger(model_dir, filename="train.log"):
-     global logger
-     logger = logging.getLogger(os.path.basename(model_dir))
-     logger.setLevel(logging.DEBUG)
-
-     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-     h = logging.FileHandler(os.path.join(model_dir, filename))
-     h.setLevel(logging.DEBUG)
-     h.setFormatter(formatter)
-     logger.addHandler(h)
-     return logger
-
-
- class HParams:
-     def __init__(self, **kwargs):
-         for k, v in kwargs.items():
-             if type(v) == dict:
-                 v = HParams(**v)
-             self[k] = v
-
-     def keys(self):
-         return self.__dict__.keys()
-
-     def items(self):
-         return self.__dict__.items()
-
-     def values(self):
-         return self.__dict__.values()
-
-     def __len__(self):
-         return len(self.__dict__)
-
-     def __getitem__(self, key):
-         return getattr(self, key)
-
-     def __setitem__(self, key, value):
-         return setattr(self, key, value)
-
-     def __contains__(self, key):
-         return key in self.__dict__
-
-     def __repr__(self):
-         return self.__dict__.__repr__()
-
-
- if __name__ == "__main__":
-     print(
-         load_wav_to_torch(
-             "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
-         )
-     )
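
For reference, the deleted HParams class exposes config values through both attribute access and dict-style access, and wraps nested dicts recursively. A minimal usage sketch with hypothetical config keys (not taken from this repository):

    hparams = HParams(sampling_rate=32000, train={"batch_size": 16})
    print(hparams.sampling_rate)        # 32000, attribute access
    print(hparams["sampling_rate"])     # 32000, __getitem__ delegates to getattr
    print(hparams.train.batch_size)     # 16, nested dict is wrapped as HParams
    hparams["new_key"] = "value"        # __setitem__ delegates to setattr
    print("new_key" in hparams)         # True, __contains__ checks __dict__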