Hemant0000 commited on
Commit
b030399
·
verified ·
1 Parent(s): fc8d5c8

Create socket_server.py

Browse files
Files changed (1) hide show
  1. src/f5_tts/socket_server.py +159 -0
src/f5_tts/socket_server.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import socket
2
+ import struct
3
+ import torch
4
+ import torchaudio
5
+ from threading import Thread
6
+
7
+
8
+ import gc
9
+ import traceback
10
+
11
+
12
+ from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
13
+ from model.backbones.dit import DiT
14
+
15
+
16
+ class TTSStreamingProcessor:
17
+ def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
18
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ # Load the model using the provided checkpoint and vocab files
21
+ self.model = load_model(
22
+ model_cls=DiT,
23
+ model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
24
+ ckpt_path=ckpt_file,
25
+ mel_spec_type="vocos", # or "bigvgan" depending on vocoder
26
+ vocab_file=vocab_file,
27
+ ode_method="euler",
28
+ use_ema=True,
29
+ device=self.device,
30
+ ).to(self.device, dtype=dtype)
31
+
32
+ # Load the vocoder
33
+ self.vocoder = load_vocoder(is_local=False)
34
+
35
+ # Set sampling rate for streaming
36
+ self.sampling_rate = 24000 # Consistency with client
37
+
38
+ # Set reference audio and text
39
+ self.ref_audio = ref_audio
40
+ self.ref_text = ref_text
41
+
42
+ # Warm up the model
43
+ self._warm_up()
44
+
45
+ def _warm_up(self):
46
+ """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
47
+ print("Warming up the model...")
48
+ ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
49
+ audio, sr = torchaudio.load(ref_audio)
50
+ gen_text = "Warm-up text for the model."
51
+
52
+ # Pass the vocoder as an argument here
53
+ infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
54
+ print("Warm-up completed.")
55
+
56
+ def generate_stream(self, text, play_steps_in_s=0.5):
57
+ """Generate audio in chunks and yield them in real-time."""
58
+ # Preprocess the reference audio and text
59
+ ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
60
+
61
+ # Load reference audio
62
+ audio, sr = torchaudio.load(ref_audio)
63
+
64
+ # Run inference for the input text
65
+ audio_chunk, final_sample_rate, _ = infer_batch_process(
66
+ (audio, sr),
67
+ ref_text,
68
+ [text],
69
+ self.model,
70
+ self.vocoder,
71
+ device=self.device, # Pass vocoder here
72
+ )
73
+
74
+ # Break the generated audio into chunks and send them
75
+ chunk_size = int(final_sample_rate * play_steps_in_s)
76
+
77
+ if len(audio_chunk) < chunk_size:
78
+ packed_audio = struct.pack(f"{len(audio_chunk)}f", *audio_chunk)
79
+ yield packed_audio
80
+ return
81
+
82
+ for i in range(0, len(audio_chunk), chunk_size):
83
+ chunk = audio_chunk[i : i + chunk_size]
84
+
85
+ # Check if it's the final chunk
86
+ if i + chunk_size >= len(audio_chunk):
87
+ chunk = audio_chunk[i:]
88
+
89
+ # Send the chunk if it is not empty
90
+ if len(chunk) > 0:
91
+ packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
92
+ yield packed_audio
93
+
94
+
95
+ def handle_client(client_socket, processor):
96
+ try:
97
+ while True:
98
+ # Receive data from the client
99
+ data = client_socket.recv(1024).decode("utf-8")
100
+ if not data:
101
+ break
102
+
103
+ try:
104
+ # The client sends the text input
105
+ text = data.strip()
106
+
107
+ # Generate and stream audio chunks
108
+ for audio_chunk in processor.generate_stream(text):
109
+ client_socket.sendall(audio_chunk)
110
+
111
+ # Send end-of-audio signal
112
+ client_socket.sendall(b"END_OF_AUDIO")
113
+
114
+ except Exception as inner_e:
115
+ print(f"Error during processing: {inner_e}")
116
+ traceback.print_exc() # Print the full traceback to diagnose the issue
117
+ break
118
+
119
+ except Exception as e:
120
+ print(f"Error handling client: {e}")
121
+ traceback.print_exc()
122
+ finally:
123
+ client_socket.close()
124
+
125
+
126
+ def start_server(host, port, processor):
127
+ server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
128
+ server.bind((host, port))
129
+ server.listen(5)
130
+ print(f"Server listening on {host}:{port}")
131
+
132
+ while True:
133
+ client_socket, addr = server.accept()
134
+ print(f"Accepted connection from {addr}")
135
+ client_handler = Thread(target=handle_client, args=(client_socket, processor))
136
+ client_handler.start()
137
+
138
+
139
+ if __name__ == "__main__":
140
+ try:
141
+ # Load the model and vocoder using the provided files
142
+ ckpt_file = "" # pointing your checkpoint "ckpts/model/model_1096.pt"
143
+ vocab_file = "" # Add vocab file path if needed
144
+ ref_audio = "" # add ref audio"./tests/ref_audio/reference.wav"
145
+ ref_text = ""
146
+
147
+ # Initialize the processor with the model and vocoder
148
+ processor = TTSStreamingProcessor(
149
+ ckpt_file=ckpt_file,
150
+ vocab_file=vocab_file,
151
+ ref_audio=ref_audio,
152
+ ref_text=ref_text,
153
+ dtype=torch.float32,
154
+ )
155
+
156
+ # Start the server
157
+ start_server("0.0.0.0", 9998, processor)
158
+ except KeyboardInterrupt:
159
+ gc.collect()