Pijush2023 committed (verified)
Commit dacfa9a · 1 Parent(s): ff05a77

Delete CHATTS/core.py

Files changed (1)
  1. CHATTS/core.py +0 -149
CHATTS/core.py DELETED
@@ -1,149 +0,0 @@
-
-import os
-import logging
-from omegaconf import OmegaConf
-
-import torch
-from vocos import Vocos
-from .model.dvae import DVAE
-from .model.gpt import GPT_warpper
-from .utils.gpu_utils import select_device
-from .utils.io_utils import get_latest_modified_file
-from .infer.api import refine_text, infer_code
-
-from huggingface_hub import snapshot_download
-
-logging.basicConfig(level = logging.INFO)
-
-
-class Chat:
-    def __init__(self, ):
-        self.pretrain_models = {}
-        self.logger = logging.getLogger(__name__)
-
-    def check_model(self, level = logging.INFO, use_decoder = False):
-        not_finish = False
-        check_list = ['vocos', 'gpt', 'tokenizer']
-
-        if use_decoder:
-            check_list.append('decoder')
-        else:
-            check_list.append('dvae')
-
-        for module in check_list:
-            if module not in self.pretrain_models:
-                self.logger.log(logging.WARNING, f'{module} not initialized.')
-                not_finish = True
-
-        if not not_finish:
-            self.logger.log(level, f'All initialized.')
-
-        return not not_finish
-
-    def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
-        if source == 'huggingface':
-            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
-            try:
-                download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
-            except:
-                download_path = None
-            if download_path is None or force_redownload:
-                self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
-                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
-            else:
-                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
-            self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
-        elif source == 'local':
-            self.logger.log(logging.INFO, f'Load from local: {local_path}')
-            self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
-
-    def _load(
-        self,
-        vocos_config_path: str = None,
-        vocos_ckpt_path: str = None,
-        dvae_config_path: str = None,
-        dvae_ckpt_path: str = None,
-        gpt_config_path: str = None,
-        gpt_ckpt_path: str = None,
-        decoder_config_path: str = None,
-        decoder_ckpt_path: str = None,
-        tokenizer_path: str = None,
-        device: str = None
-    ):
-        if not device:
-            device = select_device(4096)
-            self.logger.log(logging.INFO, f'use {device}')
-
-        if vocos_config_path:
-            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
-            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
-            vocos.load_state_dict(torch.load(vocos_ckpt_path))
-            self.pretrain_models['vocos'] = vocos
-            self.logger.log(logging.INFO, 'vocos loaded.')
-
-        if dvae_config_path:
-            cfg = OmegaConf.load(dvae_config_path)
-            dvae = DVAE(**cfg).to(device).eval()
-            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
-            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
-            self.pretrain_models['dvae'] = dvae
-            self.logger.log(logging.INFO, 'dvae loaded.')
-
-        if gpt_config_path:
-            cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT_warpper(**cfg).to(device).eval()
-            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
-            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
-            self.pretrain_models['gpt'] = gpt
-            self.logger.log(logging.INFO, 'gpt loaded.')
-
-        if decoder_config_path:
-            cfg = OmegaConf.load(decoder_config_path)
-            decoder = DVAE(**cfg).to(device).eval()
-            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
-            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
-            self.pretrain_models['decoder'] = decoder
-            self.logger.log(logging.INFO, 'decoder loaded.')
-
-        if tokenizer_path:
-            tokenizer = torch.load(tokenizer_path, map_location='cpu')
-            tokenizer.padding_side = 'left'
-            self.pretrain_models['tokenizer'] = tokenizer
-            self.logger.log(logging.INFO, 'tokenizer loaded.')
-
-        self.check_model()
-
-    def infer(
-        self,
-        text,
-        skip_refine_text=False,
-        refine_text_only=False,
-        params_refine_text={},
-        params_infer_code={},
-        use_decoder=False
-    ):
-
-        assert self.check_model(use_decoder=use_decoder)
-
-        if not skip_refine_text:
-            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
-            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
-            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
-            if refine_text_only:
-                return text
-
-        text = [params_infer_code.get('prompt', '') + i for i in text]
-        params_infer_code.pop('prompt', '')
-        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
-
-        if use_decoder:
-            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
-        else:
-            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
-
-        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
-
-        return wav
-
-
-
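For context, the deleted module implemented ChatTTS's top-level Chat API: load_models resolves checkpoints (from the Hugging Face cache, a fresh snapshot_download, or a local path), and infer runs text refinement, speech-code generation, and Vocos decoding to produce waveforms. Below is a minimal usage sketch reconstructed from the deleted code above; the input text and variable names are illustrative assumptions, not part of this commit:

    from CHATTS.core import Chat

    chat = Chat()
    # Fetches *.pt / *.yaml from the 2Noise/ChatTTS Hub repo, or reuses the local HF cache
    chat.load_models(source='huggingface')

    texts = ["Hello from ChatTTS."]   # illustrative input, not from this commit
    wavs = chat.infer(texts)          # list of numpy waveform arrays decoded by Vocos

    # refine_text_only=True returns the refined text prompts instead of audio
    refined = chat.infer(texts, refine_text_only=True)

With this file deleted, these entry points are no longer available in the repository; callers would need the upstream 2Noise/ChatTTS package for the same functionality.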