Delete CHATTS/core.py

CHATTS/core.py  DELETED  +0 -149
@@ -1,149 +0,0 @@
|
|
1 |
-
|
2 |
-
import os
|
3 |
-
import logging
|
4 |
-
from omegaconf import OmegaConf
|
5 |
-
|
6 |
-
import torch
|
7 |
-
from vocos import Vocos
|
8 |
-
from .model.dvae import DVAE
|
9 |
-
from .model.gpt import GPT_warpper
|
10 |
-
from .utils.gpu_utils import select_device
|
11 |
-
from .utils.io_utils import get_latest_modified_file
|
12 |
-
from .infer.api import refine_text, infer_code
|
13 |
-
|
14 |
-
from huggingface_hub import snapshot_download
|
15 |
-
|
16 |
-
logging.basicConfig(level = logging.INFO)
|
17 |
-
|
18 |
-
|
19 |
-
class Chat:
|
20 |
-
def __init__(self, ):
|
21 |
-
self.pretrain_models = {}
|
22 |
-
self.logger = logging.getLogger(__name__)
|
23 |
-
|
24 |
-
def check_model(self, level = logging.INFO, use_decoder = False):
|
25 |
-
not_finish = False
|
26 |
-
check_list = ['vocos', 'gpt', 'tokenizer']
|
27 |
-
|
28 |
-
if use_decoder:
|
29 |
-
check_list.append('decoder')
|
30 |
-
else:
|
31 |
-
check_list.append('dvae')
|
32 |
-
|
33 |
-
for module in check_list:
|
34 |
-
if module not in self.pretrain_models:
|
35 |
-
self.logger.log(logging.WARNING, f'{module} not initialized.')
|
36 |
-
not_finish = True
|
37 |
-
|
38 |
-
if not not_finish:
|
39 |
-
self.logger.log(level, f'All initialized.')
|
40 |
-
|
41 |
-
return not not_finish
|
42 |
-
|
43 |
-
def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
|
44 |
-
if source == 'huggingface':
|
45 |
-
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
|
46 |
-
try:
|
47 |
-
download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
|
48 |
-
except:
|
49 |
-
download_path = None
|
50 |
-
if download_path is None or force_redownload:
|
51 |
-
self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
|
52 |
-
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
|
53 |
-
else:
|
54 |
-
self.logger.log(logging.INFO, f'Load from cache: {download_path}')
|
55 |
-
self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
|
56 |
-
elif source == 'local':
|
57 |
-
self.logger.log(logging.INFO, f'Load from local: {local_path}')
|
58 |
-
self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
|
59 |
-
|
60 |
-
def _load(
|
61 |
-
self,
|
62 |
-
vocos_config_path: str = None,
|
63 |
-
vocos_ckpt_path: str = None,
|
64 |
-
dvae_config_path: str = None,
|
65 |
-
dvae_ckpt_path: str = None,
|
66 |
-
gpt_config_path: str = None,
|
67 |
-
gpt_ckpt_path: str = None,
|
68 |
-
decoder_config_path: str = None,
|
69 |
-
decoder_ckpt_path: str = None,
|
70 |
-
tokenizer_path: str = None,
|
71 |
-
device: str = None
|
72 |
-
):
|
73 |
-
if not device:
|
74 |
-
device = select_device(4096)
|
75 |
-
self.logger.log(logging.INFO, f'use {device}')
|
76 |
-
|
77 |
-
if vocos_config_path:
|
78 |
-
vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
|
79 |
-
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
|
80 |
-
vocos.load_state_dict(torch.load(vocos_ckpt_path))
|
81 |
-
self.pretrain_models['vocos'] = vocos
|
82 |
-
self.logger.log(logging.INFO, 'vocos loaded.')
|
83 |
-
|
84 |
-
if dvae_config_path:
|
85 |
-
cfg = OmegaConf.load(dvae_config_path)
|
86 |
-
dvae = DVAE(**cfg).to(device).eval()
|
87 |
-
assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
|
88 |
-
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
|
89 |
-
self.pretrain_models['dvae'] = dvae
|
90 |
-
self.logger.log(logging.INFO, 'dvae loaded.')
|
91 |
-
|
92 |
-
if gpt_config_path:
|
93 |
-
cfg = OmegaConf.load(gpt_config_path)
|
94 |
-
gpt = GPT_warpper(**cfg).to(device).eval()
|
95 |
-
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
|
96 |
-
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
|
97 |
-
self.pretrain_models['gpt'] = gpt
|
98 |
-
self.logger.log(logging.INFO, 'gpt loaded.')
|
99 |
-
|
100 |
-
if decoder_config_path:
|
101 |
-
cfg = OmegaConf.load(decoder_config_path)
|
102 |
-
decoder = DVAE(**cfg).to(device).eval()
|
103 |
-
assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
|
104 |
-
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
|
105 |
-
self.pretrain_models['decoder'] = decoder
|
106 |
-
self.logger.log(logging.INFO, 'decoder loaded.')
|
107 |
-
|
108 |
-
if tokenizer_path:
|
109 |
-
tokenizer = torch.load(tokenizer_path, map_location='cpu')
|
110 |
-
tokenizer.padding_side = 'left'
|
111 |
-
self.pretrain_models['tokenizer'] = tokenizer
|
112 |
-
self.logger.log(logging.INFO, 'tokenizer loaded.')
|
113 |
-
|
114 |
-
self.check_model()
|
115 |
-
|
116 |
-
def infer(
|
117 |
-
self,
|
118 |
-
text,
|
119 |
-
skip_refine_text=False,
|
120 |
-
refine_text_only=False,
|
121 |
-
params_refine_text={},
|
122 |
-
params_infer_code={},
|
123 |
-
use_decoder=False
|
124 |
-
):
|
125 |
-
|
126 |
-
assert self.check_model(use_decoder=use_decoder)
|
127 |
-
|
128 |
-
if not skip_refine_text:
|
129 |
-
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
130 |
-
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
|
131 |
-
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
|
132 |
-
if refine_text_only:
|
133 |
-
return text
|
134 |
-
|
135 |
-
text = [params_infer_code.get('prompt', '') + i for i in text]
|
136 |
-
params_infer_code.pop('prompt', '')
|
137 |
-
result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
|
138 |
-
|
139 |
-
if use_decoder:
|
140 |
-
mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
|
141 |
-
else:
|
142 |
-
mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
|
143 |
-
|
144 |
-
wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
|
145 |
-
|
146 |
-
return wav
|
147 |
-
|
148 |
-
|
149 |
-
|
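For context, here is a minimal, hypothetical sketch of how the deleted Chat class was typically driven. The import path, the soundfile dependency, and the 24 kHz output rate are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch (not part of this commit): assumes the removed module
# was importable as CHATTS.core and that soundfile is installed for writing audio.
import soundfile as sf

from CHATTS.core import Chat

chat = Chat()
chat.load_models()            # default source='huggingface'; fetches *.pt / *.yaml from 2Noise/ChatTTS

texts = ["A short sentence to synthesize."]
wavs = chat.infer(texts)      # returns one numpy waveform per input string

# 24000 Hz is assumed from the released ChatTTS/Vocos configuration; adjust if your config differs.
sf.write("output.wav", wavs[0].squeeze(), 24000)

To avoid the network download, the same loader can be pointed at a local checkout containing config/path.yaml via chat.load_models(source='local', local_path=...).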