|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import sys |
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) |
|
sys.path.append('{}/../..'.format(ROOT_DIR)) |
|
sys.path.append('{}/../../third_party/AcademiCodec'.format(ROOT_DIR)) |
|
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) |
|
from concurrent import futures |
|
import argparse |
|
import cosyvoice_pb2 |
|
import cosyvoice_pb2_grpc |
|
import logging |
|
logging.getLogger('matplotlib').setLevel(logging.WARNING) |
|
import grpc |
|
import torch |
|
import numpy as np |
|
from cosyvoice.cli.cosyvoice import CosyVoice |
|
|
|
# Configure root logging before any server component starts; DEBUG level keeps
# the per-request log lines emitted by Inference visible on the console.
# (matplotlib's logger is quieted separately above, in the import section.)
logging.basicConfig(level=logging.DEBUG,

                    format='%(asctime)s %(levelname)s %(message)s')
|
|
|
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
    """gRPC servicer that routes TTS requests to a shared CosyVoice model."""

    def __init__(self, args):
        # The model is loaded once at startup; every RPC reuses this instance.
        self.cosyvoice = CosyVoice(args.model_dir)
        logging.info('grpc service initialized')

    @staticmethod
    def _prompt_to_tensor(prompt_audio):
        # Decode raw int16 PCM bytes into a (1, num_samples) float tensor
        # scaled to [-1, 1) as expected by the CosyVoice inference APIs.
        # np.array() copies the read-only frombuffer view so torch gets a
        # writable array.
        pcm = np.array(np.frombuffer(prompt_audio, dtype=np.int16))
        waveform = torch.from_numpy(pcm).unsqueeze(dim=0)
        return waveform.float() / (2**15)

    def Inference(self, request, context):
        """Run one synthesis request and return the audio as int16 PCM bytes.

        Dispatches on which oneof field of the request is populated; an
        unset/unknown field falls through to the instruct branch.
        """
        if request.HasField('sft_request'):
            logging.info('get sft inference request')
            sft = request.sft_request
            model_output = self.cosyvoice.inference_sft(sft.tts_text, sft.spk_id)
        elif request.HasField('zero_shot_request'):
            logging.info('get zero_shot inference request')
            zs = request.zero_shot_request
            model_output = self.cosyvoice.inference_zero_shot(
                zs.tts_text, zs.prompt_text, self._prompt_to_tensor(zs.prompt_audio))
        elif request.HasField('cross_lingual_request'):
            logging.info('get cross_lingual inference request')
            cl = request.cross_lingual_request
            model_output = self.cosyvoice.inference_cross_lingual(
                cl.tts_text, self._prompt_to_tensor(cl.prompt_audio))
        else:
            logging.info('get instruct inference request')
            ins = request.instruct_request
            model_output = self.cosyvoice.inference_instruct(
                ins.tts_text, ins.spk_id, ins.instruct_text)

        logging.info('send inference response')
        response = cosyvoice_pb2.Response()
        # Rescale the float waveform back to int16 PCM for the wire format.
        response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
        return response
|
|
|
def main():
    """Start the gRPC TTS service and block until it terminates.

    Reads the parsed command-line options from the module-global ``args``
    (set in the ``__main__`` guard below).
    """
    # Cap both the worker threads and the number of in-flight RPCs at the
    # configured concurrency limit.
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=args.max_conc),
        maximum_concurrent_rpcs=args.max_conc)
    cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), server)
    server.add_insecure_port('0.0.0.0:{}'.format(args.port))
    server.start()
    logging.info("server listening on 0.0.0.0:{}".format(args.port))
    # Block the main thread for the lifetime of the server.
    server.wait_for_termination()
|
|
|
|
|
if __name__ == '__main__':
    # Parse command-line options into the module-global `args` used by main()
    # and CosyVoiceServiceImpl.
    parser = argparse.ArgumentParser()
    parser.add_argument('--port',
                        type=int,
                        default=50000)
    parser.add_argument('--max_conc',
                        type=int,
                        default=4)
    # Fix: the original declared both required=True and a default value.
    # argparse never applies a default to a required argument, so the default
    # was dead code that misleadingly suggested the flag was optional; it has
    # been removed (runtime behavior is unchanged).
    parser.add_argument('--model_dir',
                        type=str,
                        required=True,
                        help='local path or modelscope repo id')
    args = parser.parse_args()
    main()
|
|