Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,128 Bytes
8f52106 01e655b 02e90e4 01e655b 02e90e4 01e655b 8f52106 01e655b 8f52106 01e655b 650b56c 8f52106 01e655b 8f52106 01e655b 650b56c 8f52106 02e90e4 650b56c 01e655b 02e90e4 01e655b 02e90e4 01e655b 02e90e4 650b56c 8f52106 650b56c 01e655b 8f52106 650b56c 374f426 01e655b 02e90e4 8f52106 02e90e4 8f52106 02e90e4 8f52106 02e90e4 8f52106 02e90e4 8f52106 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import threading
import torch
from modules.ChatTTS import ChatTTS
from modules import config
from modules.devices import devices
import logging
import gc
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Shared ChatTTS instance: None until load_chat_tts_in_thread() assigns it,
# and reset to None by unload_chat_tts().
chat_tts = None
# On some platforms the model must not be loaded in the main thread,
# otherwise an error is raised.
# huggingface Error:
# RuntimeError: CUDA must not be initialized in the main process on Spaces with Stateless GPU environment.
# You can look at this Stacktrace to find out which part of your code triggered a CUDA init
# Set by load_chat_tts_in_thread() once the shared instance is available.
load_event = threading.Event()
def load_chat_tts_in_thread():
    """Load the ChatTTS models and publish them in the module-level ``chat_tts``.

    Used as the target of the loader thread started by
    ``initialize_chat_tts``. If an instance is already present, only
    ``load_event`` is set and the function returns. Otherwise a new
    ``ChatTTS.Chat`` is created and its models are loaded from
    ``./models/ChatTTS`` with the device/dtype settings taken from ``devices``.

    The global is assigned only after ``load_models`` has returned, so a
    failed load leaves ``chat_tts`` as ``None`` instead of exposing a
    half-initialized instance. When loading raises, the exception propagates
    out of this function and ``load_event`` is not set.
    """
    global chat_tts
    if chat_tts is not None:
        load_event.set()
        return

    logger.info("Loading ChatTTS models")
    # Build into a local first; publishing before load_models() finishes would
    # make a failed load look like a successful one to every later caller.
    instance = ChatTTS.Chat()
    instance.load_models(
        compile=config.runtime_env_vars.compile,
        source="local",
        local_path="./models/ChatTTS",
        device=devices.device,
        dtype=devices.dtype,
        dtype_vocos=devices.dtype_vocos,
        dtype_dvae=devices.dtype_dvae,
        dtype_gpt=devices.dtype_gpt,
        dtype_decoder=devices.dtype_decoder,
    )
    chat_tts = instance

    devices.torch_gc()
    load_event.set()
    logger.info("ChatTTS models loaded")
def initialize_chat_tts():
    """Start a thread running ``load_chat_tts_in_thread`` and return it.

    The thread is already started when it is returned; callers may ``join()``
    it to wait for the load attempt to finish.
    """
    loader = threading.Thread(target=load_chat_tts_in_thread)
    loader.start()
    return loader
def load_chat_tts():
    """Return the shared ChatTTS instance, loading it first when necessary.

    When no instance exists yet, a loader thread is started via
    ``initialize_chat_tts`` and joined before the global is checked again.

    Raises:
        Exception: if ``chat_tts`` is still ``None`` after the loader thread
            has finished.
    """
    if chat_tts is not None:
        return chat_tts

    # Block until the loader thread has finished its attempt.
    loader = initialize_chat_tts()
    loader.join()

    if chat_tts is not None:
        return chat_tts
    raise Exception("Failed to load ChatTTS models")
def unload_chat_tts():
    """Release the loaded ChatTTS models and reset the module-level ``chat_tts``.

    Every ``torch.nn.Module`` found in ``chat_tts.pretrain_models`` is moved
    to the CPU, the global reference is dropped, the garbage collector is run
    and, when CUDA is available, the CUDA cache is emptied. Calling this when
    nothing is loaded only performs the cleanup steps.
    """
    global chat_tts
    logger.info("Unloading ChatTTS models")

    if chat_tts is not None:
        for model in chat_tts.pretrain_models.values():
            if isinstance(model, torch.nn.Module):
                model.cpu()

    # Drop the reference before collecting so the models are actually
    # reclaimable, and empty the CUDA cache last so memory freed by the
    # collection is released as well.
    chat_tts = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    logger.info("ChatTTS models unloaded")
def reload_chat_tts():
    """Unload the current ChatTTS models, load them again, and return the instance.

    Returns:
        The instance returned by ``load_chat_tts`` after the reload.

    Raises:
        Exception: propagated from ``load_chat_tts`` when loading fails.
    """
    # Use the module logger (not the root ``logging`` functions) so both
    # messages of this function go through the same logger.
    logger.info("Reloading ChatTTS models")
    unload_chat_tts()
    instance = load_chat_tts()
    logger.info("ChatTTS models reloaded")
    return instance
|