File size: 2,256 Bytes
d2b7e94
 
8f52106
d2b7e94
01e655b
f367757
d2b7e94
01e655b
d2b7e94
02e90e4
01e655b
 
8f52106
01e655b
8a3a4ec
01e655b
 
8f52106
01e655b
 
8f52106
02e90e4
650b56c
01e655b
bed01bd
 
01e655b
02e90e4
01e655b
 
bed01bd
 
02e90e4
 
 
 
01e655b
 
bed01bd
 
 
 
 
 
 
02e90e4
650b56c
8f52106
 
da8d589
8a3a4ec
 
6ecb8c2
374f426
 
01e655b
02e90e4
 
8f52106
 
02e90e4
8f52106
02e90e4
8f52106
 
 
 
02e90e4
627d3d7
 
8f52106
 
 
 
 
 
 
 
 
f367757
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gc
import logging
import threading

import torch
from transformers import LlamaTokenizer

from modules import config
from modules.ChatTTS import ChatTTS
from modules.devices import devices

logger = logging.getLogger(__name__)

# Guards lazy creation of the shared ChatTTS instance (see load_chat_tts).
lock = threading.Lock()
# Shared ChatTTS.Chat instance: None until first loaded, reset to None on unload.
chat_tts = None


def load_chat_tts_in_thread():
    """Create the ChatTTS instance and load its models into ``chat_tts``.

    This function does no locking itself; callers are expected to hold the
    module-level ``lock`` (as ``load_chat_tts`` does). It is a no-op when the
    models are already loaded.

    The global is only assigned after ``load_models`` succeeds, so a failed
    load leaves ``chat_tts`` as ``None`` and a later call can retry instead
    of receiving a half-initialized instance.
    """
    global chat_tts
    if chat_tts is not None:
        return

    logger.info("Loading ChatTTS models")
    instance = ChatTTS.Chat()
    device = devices.get_device_for("chattts")
    dtype = devices.dtype
    instance.load_models(
        compile=config.runtime_env_vars.compile,
        source="local",
        local_path="./models/ChatTTS",
        device=device,
        dtype=dtype,
        dtype_vocos=devices.dtype_vocos,
        dtype_dvae=devices.dtype_dvae,
        dtype_gpt=devices.dtype_gpt,
        dtype_decoder=devices.dtype_decoder,
    )
    # Publish only once loading has fully succeeded.
    chat_tts = instance

    # If the device is CPU and the dtype is float16, warn that it may not run
    # properly and suggest float32 via the `--no_half` flag.
    if device == devices.cpu and dtype == torch.float16:
        logger.warning(
            "The device is CPU and dtype is float16, which may not work properly. It is recommended to use float32 by enabling the `--no_half` parameter."
        )

    devices.torch_gc()
    logger.info("ChatTTS models loaded")


def load_chat_tts():
    """Return the shared ChatTTS instance, loading it on first use.

    Loading happens under the module-level ``lock`` so concurrent callers
    trigger at most one load.

    Returns:
        The loaded ``ChatTTS.Chat`` instance.

    Raises:
        RuntimeError: If no instance is available after the load attempt.
        Any exception raised while loading the models propagates unchanged.
    """
    with lock:
        if chat_tts is None:
            load_chat_tts_in_thread()
        # Snapshot the global while holding the lock so the value we check
        # is the value we return, even if another thread resets `chat_tts`
        # after the lock is released.
        instance = chat_tts
    if instance is None:
        raise RuntimeError("Failed to load ChatTTS models")
    return instance


def unload_chat_tts():
    """Drop the shared ChatTTS instance and try to release device memory.

    Every ``torch.nn.Module`` in the instance's ``pretrain_models`` is moved
    to CPU before the module-level reference is cleared. Then
    ``devices.torch_gc()`` and ``gc.collect()`` run. This is safe to call
    when nothing is loaded.
    """
    global chat_tts
    logger.info("Unloading ChatTTS models")

    # Take the same lock that guards loading so an unload cannot interleave
    # with a concurrent load of the global.
    with lock:
        if chat_tts is not None:
            for model in chat_tts.pretrain_models.values():
                # Only nn.Modules hold device tensors; other entries (e.g. the
                # tokenizer) are left untouched.
                if isinstance(model, torch.nn.Module):
                    model.cpu()
        chat_tts = None

    devices.torch_gc()
    gc.collect()
    logger.info("ChatTTS models unloaded")


def reload_chat_tts():
    """Unload the shared ChatTTS instance and load a fresh one.

    The unload and the load are two separate steps, so another thread may
    briefly observe the unloaded state in between.

    Returns:
        The newly loaded ``ChatTTS.Chat`` instance.
    """
    # Use the module logger (not the root `logging` functions) for
    # consistency with the rest of this module.
    logger.info("Reloading ChatTTS models")
    unload_chat_tts()
    instance = load_chat_tts()
    logger.info("ChatTTS models reloaded")
    return instance


def get_tokenizer() -> LlamaTokenizer:
    """Return the tokenizer stored in the shared ChatTTS instance.

    The ChatTTS models are loaded first if they are not loaded yet.
    """
    return load_chat_tts().pretrain_models["tokenizer"]