Hemant0000 commited on
Commit
e036084
·
verified ·
1 Parent(s): 64f1736

Create api.py

Browse files
Files changed (1) hide show
  1. api.py +132 -0
api.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import soundfile as sf
2
+ import torch
3
+ import tqdm
4
+ from cached_path import cached_path
5
+
6
+ from model import DiT, UNetT
7
+ from model.utils import save_spectrogram
8
+
9
+ from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
10
+ from model.utils import seed_everything
11
+ import random
12
+ import sys
13
+
14
+
15
+ class F5TTS:
16
+ def __init__(
17
+ self,
18
+ model_type="F5-TTS",
19
+ ckpt_file="",
20
+ vocab_file="",
21
+ ode_method="euler",
22
+ use_ema=True,
23
+ local_path=None,
24
+ device=None,
25
+ ):
26
+ # Initialize parameters
27
+ self.final_wave = None
28
+ self.target_sample_rate = 24000
29
+ self.n_mel_channels = 100
30
+ self.hop_length = 256
31
+ self.target_rms = 0.1
32
+ self.seed = -1
33
+
34
+ # Set device
35
+ self.device = device or (
36
+ "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
37
+ )
38
+
39
+ # Load models
40
+ self.load_vocoder_model(local_path)
41
+ self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
42
+
43
+ def load_vocoder_model(self, local_path):
44
+ self.vocos = load_vocoder(local_path is not None, local_path, self.device)
45
+
46
+ def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
47
+ if model_type == "F5-TTS":
48
+ if not ckpt_file:
49
+ ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
50
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
51
+ model_cls = DiT
52
+ elif model_type == "E2-TTS":
53
+ if not ckpt_file:
54
+ ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
55
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
56
+ model_cls = UNetT
57
+ else:
58
+ raise ValueError(f"Unknown model type: {model_type}")
59
+
60
+ self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
61
+
62
+ def export_wav(self, wav, file_wave, remove_silence=False):
63
+ sf.write(file_wave, wav, self.target_sample_rate)
64
+
65
+ if remove_silence:
66
+ remove_silence_for_generated_wav(file_wave)
67
+
68
+ def export_spectrogram(self, spect, file_spect):
69
+ save_spectrogram(spect, file_spect)
70
+
71
+ def infer(
72
+ self,
73
+ ref_file,
74
+ ref_text,
75
+ gen_text,
76
+ show_info=print,
77
+ progress=tqdm,
78
+ target_rms=0.1,
79
+ cross_fade_duration=0.15,
80
+ sway_sampling_coef=-1,
81
+ cfg_strength=2,
82
+ nfe_step=32,
83
+ speed=1.0,
84
+ fix_duration=None,
85
+ remove_silence=False,
86
+ file_wave=None,
87
+ file_spect=None,
88
+ seed=-1,
89
+ ):
90
+ if seed == -1:
91
+ seed = random.randint(0, sys.maxsize)
92
+ seed_everything(seed)
93
+ self.seed = seed
94
+ wav, sr, spect = infer_process(
95
+ ref_file,
96
+ ref_text,
97
+ gen_text,
98
+ self.ema_model,
99
+ show_info=show_info,
100
+ progress=progress,
101
+ target_rms=target_rms,
102
+ cross_fade_duration=cross_fade_duration,
103
+ nfe_step=nfe_step,
104
+ cfg_strength=cfg_strength,
105
+ sway_sampling_coef=sway_sampling_coef,
106
+ speed=speed,
107
+ fix_duration=fix_duration,
108
+ device=self.device,
109
+ )
110
+
111
+ if file_wave is not None:
112
+ self.export_wav(wav, file_wave, remove_silence)
113
+
114
+ if file_spect is not None:
115
+ self.export_spectrogram(spect, file_spect)
116
+
117
+ return wav, sr, spect
118
+
119
+
120
+ if __name__ == "__main__":
121
+ f5tts = F5TTS()
122
+
123
+ wav, sr, spect = f5tts.infer(
124
+ ref_file="tests/ref_audio/test_en_1_ref_short.wav",
125
+ ref_text="some call me nature, others call me mother nature.",
126
+ gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
127
+ file_wave="tests/out.wav",
128
+ file_spect="tests/out.png",
129
+ seed=-1, # random seed = -1
130
+ )
131
+
132
+ print("seed :", f5tts.seed)