AkitoP committed
Commit 83b3a24
Parent: 4536f60

Update GPT_SoVITS/utils.py

Files changed (1)
  1. GPT_SoVITS/utils.py +372 -371
GPT_SoVITS/utils.py CHANGED
@@ -1,371 +1,372 @@
 import os
 import glob
 import sys
 import argparse
 import logging
 import json
 import subprocess
 import traceback
-
+now_dir = os.getcwd()
+sys.path.insert(0, now_dir)
 import librosa
 import numpy as np
 from scipy.io.wavfile import read
 import torch
 import logging
 
 logging.getLogger("numba").setLevel(logging.ERROR)
 logging.getLogger("matplotlib").setLevel(logging.ERROR)
 
 MATPLOTLIB_FLAG = False
 
 logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 logger = logging
 
 
 def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
     if (
         optimizer is not None
         and not skip_optimizer
         and checkpoint_dict["optimizer"] is not None
     ):
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
     saved_state_dict = checkpoint_dict["model"]
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
     new_state_dict = {}
     for k, v in state_dict.items():
         try:
             # assert "quantizer" not in k
             # print("load", k)
             new_state_dict[k] = saved_state_dict[k]
             assert saved_state_dict[k].shape == v.shape, (
                 saved_state_dict[k].shape,
                 v.shape,
             )
         except:
             traceback.print_exc()
             print(
                 "error, %s is not in the checkpoint" % k
             )  # also triggered by a shape mismatch, e.g. text_embedding after the cleaner is modified
             new_state_dict[k] = v
     if hasattr(model, "module"):
         model.module.load_state_dict(new_state_dict)
     else:
         model.load_state_dict(new_state_dict)
     print("load ")
     logger.info(
         "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
     )
     return model, optimizer, learning_rate, iteration
 
 from time import time as ttime
 import shutil
 def my_save(fea, path):  # fix issue: torch.save doesn't support Chinese paths
     dir = os.path.dirname(path)
     name = os.path.basename(path)
     tmp_path = "%s.pth" % (ttime())
     torch.save(fea, tmp_path)
     shutil.move(tmp_path, "%s/%s" % (dir, name))
 
 def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
     logger.info(
         "Saving model and optimizer state at iteration {} to {}".format(
             iteration, checkpoint_path
         )
     )
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
     # torch.save(
     my_save(
         {
             "model": state_dict,
             "iteration": iteration,
             "optimizer": optimizer.state_dict(),
             "learning_rate": learning_rate,
         },
         checkpoint_path,
     )
 
 
 def summarize(
     writer,
     global_step,
     scalars={},
     histograms={},
     images={},
     audios={},
     audio_sampling_rate=22050,
 ):
     for k, v in scalars.items():
         writer.add_scalar(k, v, global_step)
     for k, v in histograms.items():
         writer.add_histogram(k, v, global_step)
     for k, v in images.items():
         writer.add_image(k, v, global_step, dataformats="HWC")
     for k, v in audios.items():
         writer.add_audio(k, v, global_step, audio_sampling_rate)
 
 
 def latest_checkpoint_path(dir_path, regex="G_*.pth"):
     f_list = glob.glob(os.path.join(dir_path, regex))
     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
     x = f_list[-1]
     print(x)
     return x
 
 
 def plot_spectrogram_to_numpy(spectrogram):
     global MATPLOTLIB_FLAG
     if not MATPLOTLIB_FLAG:
         import matplotlib
 
         matplotlib.use("Agg")
         MATPLOTLIB_FLAG = True
         mpl_logger = logging.getLogger("matplotlib")
         mpl_logger.setLevel(logging.WARNING)
     import matplotlib.pylab as plt
     import numpy as np
 
     fig, ax = plt.subplots(figsize=(10, 2))
     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
     plt.colorbar(im, ax=ax)
     plt.xlabel("Frames")
     plt.ylabel("Channels")
     plt.tight_layout()
 
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
     plt.close()
     return data
 
 
 def plot_alignment_to_numpy(alignment, info=None):
     global MATPLOTLIB_FLAG
     if not MATPLOTLIB_FLAG:
         import matplotlib
 
         matplotlib.use("Agg")
         MATPLOTLIB_FLAG = True
         mpl_logger = logging.getLogger("matplotlib")
         mpl_logger.setLevel(logging.WARNING)
     import matplotlib.pylab as plt
     import numpy as np
 
     fig, ax = plt.subplots(figsize=(6, 4))
     im = ax.imshow(
         alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
     )
     fig.colorbar(im, ax=ax)
     xlabel = "Decoder timestep"
     if info is not None:
         xlabel += "\n\n" + info
     plt.xlabel(xlabel)
     plt.ylabel("Encoder timestep")
     plt.tight_layout()
 
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
     plt.close()
     return data
 
 
 def load_wav_to_torch(full_path):
     data, sampling_rate = librosa.load(full_path, sr=None)
     return torch.FloatTensor(data), sampling_rate
 
 
 def load_filepaths_and_text(filename, split="|"):
     with open(filename, encoding="utf-8") as f:
         filepaths_and_text = [line.strip().split(split) for line in f]
     return filepaths_and_text
 
 
 def get_hparams(init=True, stage=1):
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-c",
         "--config",
         type=str,
         default="./configs/s2.json",
         help="JSON file for configuration",
     )
     parser.add_argument(
         "-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
     )
     parser.add_argument(
         "-rs",
         "--resume_step",
         type=int,
         required=False,
         default=None,
         help="resume step",
     )
     # parser.add_argument('-e', '--exp_dir', type=str, required=False, default=None, help='experiment directory')
     # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False, default=None, help='pretrained sovits generator weights')
     # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False, default=None, help='pretrained sovits discriminator weights')
 
     args = parser.parse_args()
 
     config_path = args.config
     with open(config_path, "r") as f:
         data = f.read()
     config = json.loads(data)
 
     hparams = HParams(**config)
     hparams.pretrain = args.pretrain
     hparams.resume_step = args.resume_step
     # hparams.data.exp_dir = args.exp_dir
     if stage == 1:
         model_dir = hparams.s1_ckpt_dir
     else:
         model_dir = hparams.s2_ckpt_dir
     config_save_path = os.path.join(model_dir, "config.json")
 
     if not os.path.exists(model_dir):
         os.makedirs(model_dir)
 
     with open(config_save_path, "w") as f:
         f.write(data)
     return hparams
 
 
 def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
     """Free up space by deleting saved checkpoints.
 
     Arguments:
     path_to_models -- path to the model directory
     n_ckpts_to_keep -- number of checkpoints to keep, excluding G_0.pth and D_0.pth
     sort_by_time -- True -> delete checkpoints chronologically,
                     False -> delete checkpoints lexicographically
     """
     import re
 
     ckpts_files = [
         f
         for f in os.listdir(path_to_models)
         if os.path.isfile(os.path.join(path_to_models, f))
     ]
     name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
     time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
     sort_key = time_key if sort_by_time else name_key
     x_sorted = lambda _x: sorted(
         [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
         key=sort_key,
     )
     to_del = [
         os.path.join(path_to_models, fn)
         for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
     ]
     del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
     del_routine = lambda x: [os.remove(x), del_info(x)]
     rs = [del_routine(fn) for fn in to_del]
 
 
 def get_hparams_from_dir(model_dir):
     config_save_path = os.path.join(model_dir, "config.json")
     with open(config_save_path, "r") as f:
         data = f.read()
     config = json.loads(data)
 
     hparams = HParams(**config)
     hparams.model_dir = model_dir
     return hparams
 
 
 def get_hparams_from_file(config_path):
     with open(config_path, "r") as f:
         data = f.read()
     config = json.loads(data)
 
     hparams = HParams(**config)
     return hparams
 
 
 def check_git_hash(model_dir):
     source_dir = os.path.dirname(os.path.realpath(__file__))
     if not os.path.exists(os.path.join(source_dir, ".git")):
         logger.warn(
             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                 source_dir
             )
         )
         return
 
     cur_hash = subprocess.getoutput("git rev-parse HEAD")
 
     path = os.path.join(model_dir, "githash")
     if os.path.exists(path):
         saved_hash = open(path).read()
         if saved_hash != cur_hash:
             logger.warn(
                 "git hash values are different. {}(saved) != {}(current)".format(
                     saved_hash[:8], cur_hash[:8]
                 )
             )
     else:
         open(path, "w").write(cur_hash)
 
 
 def get_logger(model_dir, filename="train.log"):
     global logger
     logger = logging.getLogger(os.path.basename(model_dir))
     logger.setLevel(logging.DEBUG)
 
     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
     if not os.path.exists(model_dir):
         os.makedirs(model_dir)
     h = logging.FileHandler(os.path.join(model_dir, filename))
     h.setLevel(logging.DEBUG)
     h.setFormatter(formatter)
     logger.addHandler(h)
     return logger
 
 
 class HParams:
     def __init__(self, **kwargs):
         for k, v in kwargs.items():
             if type(v) == dict:
                 v = HParams(**v)
             self[k] = v
 
     def keys(self):
         return self.__dict__.keys()
 
     def items(self):
         return self.__dict__.items()
 
     def values(self):
         return self.__dict__.values()
 
     def __len__(self):
         return len(self.__dict__)
 
     def __getitem__(self, key):
         return getattr(self, key)
 
     def __setitem__(self, key, value):
         return setattr(self, key, value)
 
     def __contains__(self, key):
         return key in self.__dict__
 
     def __repr__(self):
         return self.__dict__.__repr__()
 
 
 if __name__ == "__main__":
     print(
         load_wav_to_torch(
             "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
         )
     )
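
The functional change in this commit is the pair of lines added near the top of the file (new lines 9-10): now_dir = os.getcwd() and sys.path.insert(0, now_dir); every other line is unchanged context. Prepending the working directory to the module search path lets utils.py and the scripts that import it resolve repo-local packages when launched from the repository root. A minimal sketch of the effect; the GPT_SoVITS.feature_extractor import named in the comment is only an illustration, not something utils.py itself imports:

import os
import sys

# Prepend the process's working directory, as the commit does. Index 0 means
# it is searched before site-packages, so a repo-local package shadows any
# installed copy with the same name.
now_dir = os.getcwd()
sys.path.insert(0, now_dir)

# Started from the repo root, absolute imports of sibling packages now
# resolve without installing the project, e.g.:
#   import GPT_SoVITS.feature_extractor  # hypothetical example import
print(sys.path[0])  # the directory the process was launched from

Without the insert, running python GPT_SoVITS/utils.py from the repo root puts only the script's own directory (GPT_SoVITS/) at the front of sys.path, so absolute GPT_SoVITS.* imports can fail.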
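
As a usage note on the HParams class at the bottom of the file: it wraps a JSON config dict as an attribute-accessible object and recurses into nested dicts, which is why code elsewhere can write hparams.s1_ckpt_dir or hparams.data.exp_dir. A small sketch, assuming the HParams class above is in scope; the config values here are made up for illustration:

config = {"s2_ckpt_dir": "logs/s2", "train": {"batch_size": 16}}
hps = HParams(**config)
print(hps.train.batch_size)     # 16 -- nested dicts become nested HParams
print("s2_ckpt_dir" in hps)     # True -- __contains__ checks the attribute dict
hps["train"]["batch_size"] = 8  # __getitem__/__setitem__ mirror attribute access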