# File size: 5,367 Bytes
# 2097ba5
import codecs
import gc
import socket
import struct
import traceback
from threading import Thread

import torch
import torchaudio

from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
from model.backbones.dit import DiT
class TTSStreamingProcessor:
    """Synthesize speech with a DiT model and hand it out as packed float chunks.

    The constructor loads the model checkpoint and the vocoder, stores the
    reference audio/text that is reused for every request, and runs one
    warm-up inference.
    """

    def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
        """Load the model and vocoder, remember the reference prompt, and warm up.

        Args:
            ckpt_file: Checkpoint path passed to ``load_model``.
            vocab_file: Vocabulary path passed to ``load_model``.
            ref_audio: Reference audio handed to ``preprocess_ref_audio_text``.
            ref_text: Reference text handed to ``preprocess_ref_audio_text``.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
            dtype: Dtype the model is cast to.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load the model using the provided checkpoint and vocab files.
        self.model = load_model(
            DiT,
            dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
            ckpt_file,
            vocab_file,
        ).to(self.device, dtype=dtype)

        # Load the vocoder.
        self.vocoder = load_vocoder(is_local=False)

        # Sampling rate for streaming; kept consistent with the client.
        self.sampling_rate = 24000

        # Reference audio and text reused for every synthesis request.
        self.ref_audio = ref_audio
        self.ref_text = ref_text

        self._warm_up()

    def _load_reference(self):
        """Preprocess the stored reference prompt and load its audio.

        Returns:
            A ``((waveform, sample_rate), ref_text)`` pair, i.e. the first two
            positional arguments of ``infer_batch_process``.
        """
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
        audio, sr = torchaudio.load(ref_audio)
        return (audio, sr), ref_text

    def _warm_up(self):
        """Run one throw-away inference so the model is ready for real-time use."""
        print("Warming up the model...")
        reference, ref_text = self._load_reference()
        gen_text = "Warm-up text for the model."
        # The result is discarded on purpose; only the side effect of running
        # the model once matters here.
        infer_batch_process(reference, ref_text, [gen_text], self.model, self.vocoder, device=self.device)
        print("Warm-up completed.")

    @staticmethod
    def _iter_packed_chunks(samples, chunk_size):
        """Yield ``samples`` in consecutive slices packed as C floats.

        Args:
            samples: Sized, sliceable sequence of numbers.
            chunk_size: Maximum number of samples per chunk; must be >= 1.

        Yields:
            ``bytes`` produced by ``struct.pack`` with format ``"<n>f"``
            (native byte order). Every sample appears exactly once; only the
            last chunk can be shorter than ``chunk_size``.
        """
        for start in range(0, len(samples), chunk_size):
            # Slicing clamps at the end of the sequence, so the final slice is
            # the remainder and never needs to be sent separately.
            chunk = samples[start : start + chunk_size]
            yield struct.pack(f"{len(chunk)}f", *chunk)

    def generate_stream(self, text, play_steps_in_s=0.5):
        """Synthesize ``text`` and yield the resulting audio in packed chunks.

        The whole utterance is generated by ``infer_batch_process`` before the
        first chunk is yielded; chunking only controls how the result is cut
        up for sending.

        Args:
            text: Text to synthesize.
            play_steps_in_s: Duration, in seconds, of each yielded chunk.

        Yields:
            ``bytes`` objects holding consecutive samples packed as C floats
            (see ``_iter_packed_chunks``). The trailing partial chunk is
            yielded once, not repeated.
        """
        reference, ref_text = self._load_reference()

        # Run inference for the input text.
        # NOTE(review): the generated audio is assumed to be a 1-D sequence of
        # samples, since it is sliced and measured with len() - confirm.
        audio, final_sample_rate, _ = infer_batch_process(
            reference,
            ref_text,
            [text],
            self.model,
            self.vocoder,
            device=self.device,
        )

        # Clamp to one sample so a very small step can never produce a zero
        # stride, which range() rejects.
        chunk_size = max(1, int(final_sample_rate * play_steps_in_s))
        yield from self._iter_packed_chunks(audio, chunk_size)
def handle_client(client_socket, processor):
    """Serve one client: read text requests and stream synthesized audio back.

    Every non-empty piece of decoded text received from the socket is treated
    as one request (messages are not framed). For each request the chunks from
    ``processor.generate_stream`` are sent in order, followed by the marker
    ``b"END_OF_AUDIO"``. The loop ends when the peer closes the connection or
    when an error occurs; the socket is always closed on exit.

    Args:
        client_socket: Connected socket for this client.
        processor: Object exposing ``generate_stream(text)`` that yields
            ``bytes`` chunks.
    """
    # An incremental decoder buffers the leading bytes of a multi-byte UTF-8
    # character that was split across two recv() calls; a plain
    # bytes.decode() would raise UnicodeDecodeError and drop the connection.
    decoder = codecs.getincrementaldecoder("utf-8")()
    try:
        while True:
            # Receive data from the client.
            raw = client_socket.recv(1024)
            if not raw:
                # Empty read: the peer closed the connection.
                break

            data = decoder.decode(raw)
            if not data:
                # Only part of a character has arrived so far; wait for more.
                continue

            try:
                # The client sends the text input.
                text = data.strip()

                # Generate and stream audio chunks.
                for audio_chunk in processor.generate_stream(text):
                    client_socket.sendall(audio_chunk)

                # Send end-of-audio signal.
                client_socket.sendall(b"END_OF_AUDIO")
            except Exception as inner_e:
                print(f"Error during processing: {inner_e}")
                traceback.print_exc()  # Print the full traceback to diagnose the issue
                break
    except Exception as e:
        print(f"Error handling client: {e}")
        traceback.print_exc()
    finally:
        client_socket.close()
def start_server(host, port, processor):
    """Listen on ``host:port`` and serve every connection on its own thread.

    Runs until an exception (for example ``KeyboardInterrupt``) propagates out
    of ``accept``; the listening socket is closed on the way out.

    Args:
        host: Interface address to bind to.
        port: TCP port to bind to.
        processor: Object passed unchanged to ``handle_client`` for each
            connection.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        # Allow an immediate restart while sockets from a previous run are
        # still lingering in TIME_WAIT.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((host, port))
        server.listen(5)
        print(f"Server listening on {host}:{port}")

        while True:
            client_socket, addr = server.accept()
            print(f"Accepted connection from {addr}")
            # Daemon threads do not keep the process alive, so interrupting
            # the accept loop shuts the server down even with clients attached.
            client_handler = Thread(
                target=handle_client,
                args=(client_socket, processor),
                daemon=True,
            )
            client_handler.start()
if __name__ == "__main__":
    try:
        # Placeholders: fill these in before launching the server.
        settings = {
            "ckpt_file": "",  # checkpoint path, e.g. "ckpts/model/model_1096.pt"
            "vocab_file": "",  # vocab file path, if needed
            "ref_audio": "",  # reference audio, e.g. "./tests/ref_audio/reference.wav"
            "ref_text": "",
            "dtype": torch.float32,
        }

        # Build the processor from the settings above, then serve on all
        # interfaces, port 9998, until interrupted.
        start_server("0.0.0.0", 9998, TTSStreamingProcessor(**settings))
    except KeyboardInterrupt:
        gc.collect()
|