Gregniuki committed
Commit
2097ba5
1 Parent(s): cb028f8

Upload 2 files

Files changed (2)
  1. f5-tts/api.py +151 -0
  2. f5-tts/socket.py +159 -0
f5-tts/api.py ADDED
@@ -0,0 +1,151 @@
+ import random
+ import sys
+ from importlib.resources import files
+
+ import soundfile as sf
+ import torch
+ import tqdm
+ from cached_path import cached_path
+
+ from f5_tts.infer.utils_infer import (
+     hop_length,
+     infer_process,
+     load_model,
+     load_vocoder,
+     preprocess_ref_audio_text,
+     remove_silence_for_generated_wav,
+     save_spectrogram,
+     target_sample_rate,
+ )
+ from f5_tts.model import DiT, UNetT
+ from f5_tts.model.utils import seed_everything
+
+
+ class F5TTS:
+     def __init__(
+         self,
+         model_type="F5-TTS",
+         ckpt_file="",
+         vocab_file="",
+         ode_method="euler",
+         use_ema=True,
+         vocoder_name="vocos",
+         local_path=None,
+         device=None,
+     ):
+         # Initialize parameters
+         self.final_wave = None
+         self.target_sample_rate = target_sample_rate
+         self.hop_length = hop_length
+         self.seed = -1
+         self.mel_spec_type = vocoder_name
+
+         # Set device
+         self.device = device or (
+             "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+         )
+
+         # Load models
+         self.load_vocoder_model(vocoder_name, local_path)
+         self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
+
+     def load_vocoder_model(self, vocoder_name, local_path):
+         self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
+
+     def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
+         if model_type == "F5-TTS":
+             if not ckpt_file:
+                 if mel_spec_type == "vocos":
+                     ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                 elif mel_spec_type == "bigvgan":
+                     ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
+             model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+             model_cls = DiT
+         elif model_type == "E2-TTS":
+             if not ckpt_file:
+                 ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
+             model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+             model_cls = UNetT
+         else:
+             raise ValueError(f"Unknown model type: {model_type}")
+
+         self.ema_model = load_model(
+             model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
+         )
+
+     def export_wav(self, wav, file_wave, remove_silence=False):
+         sf.write(file_wave, wav, self.target_sample_rate)
+
+         if remove_silence:
+             remove_silence_for_generated_wav(file_wave)
+
+     def export_spectrogram(self, spect, file_spect):
+         save_spectrogram(spect, file_spect)
+
+     def infer(
+         self,
+         ref_file,
+         ref_text,
+         gen_text,
+         show_info=print,
+         progress=tqdm,
+         target_rms=0.1,
+         cross_fade_duration=0.15,
+         sway_sampling_coef=-1,
+         cfg_strength=2,
+         nfe_step=32,
+         speed=1.0,
+         fix_duration=None,
+         remove_silence=False,
+         file_wave=None,
+         file_spect=None,
+         seed=-1,
+     ):
+         if seed == -1:
+             seed = random.randint(0, sys.maxsize)
+         seed_everything(seed)
+         self.seed = seed
+
+         ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
+
+         wav, sr, spect = infer_process(
+             ref_file,
+             ref_text,
+             gen_text,
+             self.ema_model,
+             self.vocoder,
+             self.mel_spec_type,
+             show_info=show_info,
+             progress=progress,
+             target_rms=target_rms,
+             cross_fade_duration=cross_fade_duration,
+             nfe_step=nfe_step,
+             cfg_strength=cfg_strength,
+             sway_sampling_coef=sway_sampling_coef,
+             speed=speed,
+             fix_duration=fix_duration,
+             device=self.device,
+         )
+
+         if file_wave is not None:
+             self.export_wav(wav, file_wave, remove_silence)
+
+         if file_spect is not None:
+             self.export_spectrogram(spect, file_spect)
+
+         return wav, sr, spect
+
+
+ if __name__ == "__main__":
+     f5tts = F5TTS()
+
+     wav, sr, spect = f5tts.infer(
+         ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
+         ref_text="some call me nature, others call me mother nature.",
+         gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
+         file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
+         file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
+         seed=-1,  # -1 selects a random seed
+     )
+
+     print("seed :", f5tts.seed)
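
The `__main__` block above already exercises the API end to end with the packaged example audio. As a further illustration, here is a minimal usage sketch with non-default options; it is not part of this commit, the paths are hypothetical placeholders, and the import assumes the file is importable as the upstream f5_tts.api module:

    from f5_tts.api import F5TTS

    # Hypothetical paths; substitute your own reference clip and output locations.
    tts = F5TTS(
        model_type="E2-TTS",   # or the default "F5-TTS"
        vocoder_name="vocos",  # or "bigvgan"
        device="cuda",         # auto-detected (cuda / mps / cpu) when left as None
    )

    wav, sr, spect = tts.infer(
        ref_file="ref.wav",
        ref_text="Transcript of the reference clip.",
        gen_text="Text to synthesize.",
        seed=42,              # fixed seed for reproducible output; -1 draws a random one
        file_wave="out.wav",  # optional: also write the result to disk
    )
    print("seed used:", tts.seed)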
f5-tts/socket.py ADDED
@@ -0,0 +1,159 @@
+ import socket
+ import struct
+ import torch
+ import torchaudio
+ from threading import Thread
+
+
+ import gc
+ import traceback
+
+
+ from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
+ from model.backbones.dit import DiT
+
+
+ class TTSStreamingProcessor:
+     def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Load the model using the provided checkpoint and vocab files
+         self.model = load_model(
+             DiT,
+             dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
+             ckpt_file,
+             vocab_file=vocab_file,  # keyword arg, so it is not mistaken for the mel_spec_type positional
+         ).to(self.device, dtype=dtype)
+
+         # Load the vocoder
+         self.vocoder = load_vocoder(is_local=False)
+
+         # Set sampling rate for streaming; must match what the client expects
+         self.sampling_rate = 24000
+
+         # Set reference audio and text
+         self.ref_audio = ref_audio
+         self.ref_text = ref_text
+
+         # Warm up the model
+         self._warm_up()
+
+     def _warm_up(self):
+         """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
+         print("Warming up the model...")
+         ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
+         audio, sr = torchaudio.load(ref_audio)
+         gen_text = "Warm-up text for the model."
+
+         # Run one inference pass (with the vocoder) so later requests hit a warm model
+         infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
+         print("Warm-up completed.")
+
+     def generate_stream(self, text, play_steps_in_s=0.5):
+         """Generate audio and yield it to the client in fixed-size chunks."""
+         # Preprocess the reference audio and text
+         ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
+
+         # Load reference audio
+         audio, sr = torchaudio.load(ref_audio)
+
+         # Run inference for the input text
+         audio_chunk, final_sample_rate, _ = infer_batch_process(
+             (audio, sr),
+             ref_text,
+             [text],
+             self.model,
+             self.vocoder,
+             device=self.device,
+         )
+
+         # Break the generated audio into chunks and send them. Slicing clamps at the
+         # end of the array, so the last (possibly shorter) chunk is sent exactly once
+         # and no tail is re-sent.
+         chunk_size = int(final_sample_rate * play_steps_in_s)
+
+         for i in range(0, len(audio_chunk), chunk_size):
+             chunk = audio_chunk[i : i + chunk_size]
+
+             # Avoid sending empty chunks
+             if len(chunk) == 0:
+                 break
+
+             # Pack the float32 samples and send the audio chunk
+             packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
+             yield packed_audio
+
+
+ def handle_client(client_socket, processor):
+     try:
+         while True:
+             # Receive data from the client
+             data = client_socket.recv(1024).decode("utf-8")
+             if not data:
+                 break
+
+             try:
+                 # The client sends the text input
+                 text = data.strip()
+
+                 # Generate and stream audio chunks
+                 for audio_chunk in processor.generate_stream(text):
+                     client_socket.sendall(audio_chunk)
+
+                 # Send end-of-audio signal
+                 client_socket.sendall(b"END_OF_AUDIO")
+
+             except Exception as inner_e:
+                 print(f"Error during processing: {inner_e}")
+                 traceback.print_exc()  # Print the full traceback to diagnose the issue
+                 break
+
+     except Exception as e:
+         print(f"Error handling client: {e}")
+         traceback.print_exc()
+     finally:
+         client_socket.close()
+
+
+ def start_server(host, port, processor):
+     server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+     server.bind((host, port))
+     server.listen(5)
+     print(f"Server listening on {host}:{port}")
+
+     while True:
+         client_socket, addr = server.accept()
+         print(f"Accepted connection from {addr}")
+         client_handler = Thread(target=handle_client, args=(client_socket, processor))
+         client_handler.start()
+
+
+ if __name__ == "__main__":
+     try:
+         # Load the model and vocoder using the provided files
+         ckpt_file = ""  # Point this at your checkpoint, e.g. "ckpts/model/model_1096.pt"
+         vocab_file = ""  # Add vocab file path if needed
+         ref_audio = ""  # Path to the reference audio, e.g. "./tests/ref_audio/reference.wav"
+         ref_text = ""  # Transcript of the reference audio
+
+         # Initialize the processor with the model and vocoder
+         processor = TTSStreamingProcessor(
+             ckpt_file=ckpt_file,
+             vocab_file=vocab_file,
+             ref_audio=ref_audio,
+             ref_text=ref_text,
+             dtype=torch.float32,
+         )
+
+         # Start the server
+         start_server("0.0.0.0", 9998, processor)
+     except KeyboardInterrupt:
+         gc.collect()
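
The wire protocol implied by handle_client and generate_stream is: the client sends UTF-8 text, the server streams back native-endian float32 samples at 24000 Hz, and each utterance ends with the literal bytes END_OF_AUDIO. A minimal client sketch, not part of this commit; the host and port match the defaults above, and it assumes the marker bytes never occur inside the packed audio:

    import socket
    import struct

    def request_tts(text, host="localhost", port=9998):
        """Send text to the streaming server and collect the returned float32 samples."""
        marker = b"END_OF_AUDIO"
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        client.connect((host, port))
        client.sendall(text.encode("utf-8"))

        buffer = b""
        while True:
            data = client.recv(4096)
            if not data:
                break
            buffer += data
            # The server appends the marker after the last chunk of an utterance
            if buffer.endswith(marker):
                buffer = buffer[: -len(marker)]
                break
        client.close()

        # Samples were packed with struct.pack(f"{n}f", ...), i.e. native-endian float32
        return struct.unpack(f"{len(buffer) // 4}f", buffer)

    if __name__ == "__main__":
        samples = request_tts("Hello from the streaming client.")
        print(f"Received {len(samples)} samples ({len(samples) / 24000:.2f} s at 24 kHz)")

A real player would feed each chunk to an audio device as it arrives instead of buffering the whole utterance; this sketch only demonstrates the framing.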