Sprt98 committed on
Commit
7f12d4b
·
verified ·
1 Parent(s): 1881570

Delete config.py

Browse files
Files changed (1) hide show
  1. config.py +0 -261
config.py DELETED
@@ -1,261 +0,0 @@
1
- """
2
- @Desc: 全局配置文件读取
3
- """
4
-
5
- import argparse
6
- import yaml
7
- from typing import Dict, List
8
- import os
9
- import shutil
10
- import sys
11
-
12
-
13
- class Resample_config:
14
- """重采样配置"""
15
-
16
- def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
17
- self.sampling_rate: int = sampling_rate # 目标采样率
18
- self.in_dir: str = in_dir # 待处理音频目录路径
19
- self.out_dir: str = out_dir # 重采样输出路径
20
-
21
- @classmethod
22
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
23
- """从字典中生成实例"""
24
-
25
- # 不检查路径是否有效,此逻辑在resample.py中处理
26
- data["in_dir"] = os.path.join(dataset_path, data["in_dir"])
27
- data["out_dir"] = os.path.join(dataset_path, data["out_dir"])
28
-
29
- return cls(**data)
30
-
31
-
32
- class Preprocess_text_config:
33
- """数据预处理配置"""
34
-
35
- def __init__(
36
- self,
37
- transcription_path: str,
38
- cleaned_path: str,
39
- train_path: str,
40
- val_path: str,
41
- config_path: str,
42
- val_per_lang: int = 5,
43
- max_val_total: int = 10000,
44
- clean: bool = True,
45
- ):
46
- self.transcription_path: str = (
47
- transcription_path # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
48
- )
49
- self.cleaned_path: str = (
50
- cleaned_path # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
51
- )
52
- self.train_path: str = (
53
- train_path # 训练集路径,可以不填。不填则将在原始文本目录生成
54
- )
55
- self.val_path: str = (
56
- val_path # 验证集路径,可以不填。不填则将在原始文本目录生成
57
- )
58
- self.config_path: str = config_path # 配置文件路径
59
- self.val_per_lang: int = val_per_lang # 每个语言的验证集条数
60
- self.max_val_total: int = (
61
- max_val_total # 验证集最大条数,多于的会被截断并放到训练集中
62
- )
63
- self.clean: bool = clean # 是否进行数据清洗
64
-
65
- @classmethod
66
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
67
- """从字典中生成实例"""
68
-
69
- data["transcription_path"] = os.path.join(
70
- dataset_path, data["transcription_path"]
71
- )
72
- if data["cleaned_path"] == "" or data["cleaned_path"] is None:
73
- data["cleaned_path"] = None
74
- else:
75
- data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"])
76
- data["train_path"] = os.path.join(dataset_path, data["train_path"])
77
- data["val_path"] = os.path.join(dataset_path, data["val_path"])
78
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
79
-
80
- return cls(**data)
81
-
82
-
83
- class Bert_gen_config:
84
- """bert_gen 配置"""
85
-
86
- def __init__(
87
- self,
88
- config_path: str,
89
- num_processes: int = 2,
90
- device: str = "cuda",
91
- use_multi_device: bool = False,
92
- ):
93
- self.config_path = config_path
94
- self.num_processes = num_processes
95
- self.device = device
96
- self.use_multi_device = use_multi_device
97
-
98
- @classmethod
99
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
100
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
101
-
102
- return cls(**data)
103
-
104
-
105
- class Emo_gen_config:
106
- """emo_gen 配置"""
107
-
108
- def __init__(
109
- self,
110
- config_path: str,
111
- num_processes: int = 2,
112
- device: str = "cuda",
113
- use_multi_device: bool = False,
114
- ):
115
- self.config_path = config_path
116
- self.num_processes = num_processes
117
- self.device = device
118
- self.use_multi_device = use_multi_device
119
-
120
- @classmethod
121
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
122
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
123
-
124
- return cls(**data)
125
-
126
-
127
- class Train_ms_config:
128
- """训练配置"""
129
-
130
- def __init__(
131
- self,
132
- config_path: str,
133
- env: Dict[str, any],
134
- base: Dict[str, any],
135
- model: str,
136
- num_workers: int,
137
- spec_cache: bool,
138
- keep_ckpts: int,
139
- ):
140
- self.env = env # 需要加载的环境变量
141
- self.base = base # 底模配置
142
- self.model = (
143
- model # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录
144
- )
145
- self.config_path = config_path # 配置文件路径
146
- self.num_workers = num_workers # worker数量
147
- self.spec_cache = spec_cache # 是否启用spec缓存
148
- self.keep_ckpts = keep_ckpts # ckpt数量
149
-
150
- @classmethod
151
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
152
- # data["model"] = os.path.join(dataset_path, data["model"])
153
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
154
-
155
- return cls(**data)
156
-
157
-
158
- class Webui_config:
159
- """webui 配置"""
160
-
161
- def __init__(
162
- self,
163
- device: str,
164
- model: str,
165
- config_path: str,
166
- language_identification_library: str,
167
- port: int = 7860,
168
- share: bool = False,
169
- debug: bool = False,
170
- ):
171
- self.device: str = device
172
- self.model: str = model # 模型路径
173
- self.config_path: str = config_path # 配置文件路径
174
- self.port: int = port # 端口号
175
- self.share: bool = share # 是否公开部署,对外网开放
176
- self.debug: bool = debug # 是否开启debug模式
177
- self.language_identification_library: str = (
178
- language_identification_library # 语种识别库
179
- )
180
-
181
- @classmethod
182
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
183
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
184
- data["model"] = os.path.join(dataset_path, data["model"])
185
- return cls(**data)
186
-
187
-
188
- class Server_config:
189
- def __init__(
190
- self, models: List[Dict[str, any]], port: int = 5000, device: str = "cuda"
191
- ):
192
- self.models: List[Dict[str, any]] = models # 需要加载的所有模型的配置
193
- self.port: int = port # 端口号
194
- self.device: str = device # 模型默认使用设备
195
-
196
- @classmethod
197
- def from_dict(cls, data: Dict[str, any]):
198
- return cls(**data)
199
-
200
-
201
- class Translate_config:
202
- """翻译api配置"""
203
-
204
- def __init__(self, app_key: str, secret_key: str):
205
- self.app_key = app_key
206
- self.secret_key = secret_key
207
-
208
- @classmethod
209
- def from_dict(cls, data: Dict[str, any]):
210
- return cls(**data)
211
-
212
-
213
- class Config:
214
- def __init__(self, config_path: str):
215
- if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"):
216
- shutil.copy(src="default_config.yml", dst=config_path)
217
- print(
218
- f"已根据默认配置文件default_config.yml生成配置文件{config_path}。请按该配置文件的说明进行配置后重新运行。"
219
- )
220
- print("如无特殊需求,请勿修改default_config.yml或备份该文件。")
221
- sys.exit(0)
222
- with open(file=config_path, mode="r", encoding="utf-8") as file:
223
- yaml_config: Dict[str, any] = yaml.safe_load(file.read())
224
- dataset_path: str = yaml_config["dataset_path"]
225
- openi_token: str = yaml_config["openi_token"]
226
- self.dataset_path: str = dataset_path
227
- self.mirror: str = yaml_config["mirror"]
228
- self.openi_token: str = openi_token
229
- self.resample_config: Resample_config = Resample_config.from_dict(
230
- dataset_path, yaml_config["resample"]
231
- )
232
- self.preprocess_text_config: Preprocess_text_config = (
233
- Preprocess_text_config.from_dict(
234
- dataset_path, yaml_config["preprocess_text"]
235
- )
236
- )
237
- self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict(
238
- dataset_path, yaml_config["bert_gen"]
239
- )
240
- self.emo_gen_config: Emo_gen_config = Emo_gen_config.from_dict(
241
- dataset_path, yaml_config["emo_gen"]
242
- )
243
- self.train_ms_config: Train_ms_config = Train_ms_config.from_dict(
244
- dataset_path, yaml_config["train_ms"]
245
- )
246
- self.webui_config: Webui_config = Webui_config.from_dict(
247
- dataset_path, yaml_config["webui"]
248
- )
249
- self.server_config: Server_config = Server_config.from_dict(
250
- yaml_config["server"]
251
- )
252
- self.translate_config: Translate_config = Translate_config.from_dict(
253
- yaml_config["translate"]
254
- )
255
-
256
-
257
- parser = argparse.ArgumentParser()
258
- # 为避免与以前的config.json起冲突,将其更名如下
259
- parser.add_argument("-y", "--yml_config", type=str, default="config.yml")
260
- args, _ = parser.parse_known_args()
261
- config = Config(args.yml_config)