File size: 6,146 Bytes
7def60a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python3
"""
Extra gRPC server for OpenVoice models.
"""
from concurrent import futures

import argparse
import signal
import sys
import os
import torch
from openvoice import se_extractor
from openvoice.api import ToneColorConverter
from melo.api import TTS

import time
import backend_pb2
import backend_pb2_grpc

import grpc


_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer for the backend service.

    This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
    """
    def Health(self, request, context):
        """
        A gRPC method that returns the health status of the backend service.

        Args:
            request: A HealthRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Reply object that contains the health status of the backend service.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        A gRPC method that loads a model into memory.

        Args:
            request: A LoadModelRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Result object that contains the result of the LoadModel operation.
        """
        model_name = request.Model
        try:

            self.clonedVoice = False
            # Assume directory from request.ModelFile.
            # Only if request.LoraAdapter it's not an absolute path
            if request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
                # get base path of modelFile
                modelFileBase = os.path.dirname(request.ModelFile)
                request.AudioPath = os.path.join(modelFileBase, request.AudioPath)
            if request.AudioPath != "":
                self.clonedVoice = True

            self.modelpath = request.ModelFile
            self.speaker = request.Type
            self.ClonedVoicePath = request.AudioPath
            
            ckpt_converter = request.Model+'/converter'
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            self.device = device
            self.tone_color_converter = None
            if self.clonedVoice:
                self.tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
                self.tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
       
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        model_name = request.model
        if model_name == "":
            return backend_pb2.Result(success=False, message="request.model is required")
        try:
            # Speed is adjustable
            speed = 1.0
            voice = "EN"
            if request.voice:
                voice = request.voice
            model = TTS(language=voice, device=self.device)
            speaker_ids = model.hps.data.spk2id
            speaker_key = self.speaker
            modelpath = self.modelpath
            for s in speaker_ids.keys():
                print(f"Speaker: {s} - ID: {speaker_ids[s]}")
            speaker_id = speaker_ids[speaker_key]
            speaker_key = speaker_key.lower().replace('_', '-')
            source_se = torch.load(f'{modelpath}/base_speakers/ses/{speaker_key}.pth', map_location=self.device)
            model.tts_to_file(request.text, speaker_id, request.dst, speed=speed)
            if self.clonedVoice:
                reference_speaker = self.ClonedVoicePath
                target_se, audio_name = se_extractor.get_se(reference_speaker, self.tone_color_converter, vad=False)
                # Run the tone color converter
                encode_message = "@MyShell"
                self.tone_color_converter.convert(
                    audio_src_path=request.dst, 
                    src_se=source_se, 
                    tgt_se=target_se, 
                    output_path=request.dst,
                    message=encode_message)
           
            print("[OpenVoice] TTS generated!", file=sys.stderr)
            print("[OpenVoice] TTS saved to", request.dst, file=sys.stderr)
            print(request, file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("[OpenVoice] Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("[OpenVoice] Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()
    print(f"[OpenVoice] startup: {args}", file=sys.stderr)
    serve(args.addr)