Spaces:
Running
on
Zero
Running
on
Zero
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- src/f5_tts/api.py +1 -1
- src/f5_tts/infer/infer_cli.py +13 -2
- src/f5_tts/infer/utils_infer.py +3 -3
src/f5_tts/api.py
CHANGED
@@ -119,7 +119,7 @@ class F5TTS:
|
|
119 |
seed_everything(seed)
|
120 |
self.seed = seed
|
121 |
|
122 |
-
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text
|
123 |
|
124 |
wav, sr, spec = infer_process(
|
125 |
ref_file,
|
|
|
119 |
seed_everything(seed)
|
120 |
self.seed = seed
|
121 |
|
122 |
+
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
|
123 |
|
124 |
wav, sr, spec = infer_process(
|
125 |
ref_file,
|
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -162,6 +162,11 @@ parser.add_argument(
|
|
162 |
type=float,
|
163 |
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
165 |
args = parser.parse_args()
|
166 |
|
167 |
|
@@ -202,6 +207,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
|
|
202 |
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
|
203 |
speed = args.speed or config.get("speed", speed)
|
204 |
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
|
|
|
205 |
|
206 |
|
207 |
# patches for pip pkg user
|
@@ -239,7 +245,9 @@ if vocoder_name == "vocos":
|
|
239 |
elif vocoder_name == "bigvgan":
|
240 |
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
241 |
|
242 |
-
vocoder = load_vocoder(
|
|
|
|
|
243 |
|
244 |
|
245 |
# load TTS model
|
@@ -270,7 +278,9 @@ if not ckpt_file:
|
|
270 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
271 |
|
272 |
print(f"Using {model}...")
|
273 |
-
ema_model = load_model(
|
|
|
|
|
274 |
|
275 |
|
276 |
# inference process
|
@@ -326,6 +336,7 @@ def main():
|
|
326 |
sway_sampling_coef=sway_sampling_coef,
|
327 |
speed=speed,
|
328 |
fix_duration=fix_duration,
|
|
|
329 |
)
|
330 |
generated_audio_segments.append(audio_segment)
|
331 |
|
|
|
162 |
type=float,
|
163 |
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
|
164 |
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--device",
|
167 |
+
type=str,
|
168 |
+
help="Specify the device to run on",
|
169 |
+
)
|
170 |
args = parser.parse_args()
|
171 |
|
172 |
|
|
|
207 |
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
|
208 |
speed = args.speed or config.get("speed", speed)
|
209 |
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
|
210 |
+
device = args.device
|
211 |
|
212 |
|
213 |
# patches for pip pkg user
|
|
|
245 |
elif vocoder_name == "bigvgan":
|
246 |
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
247 |
|
248 |
+
vocoder = load_vocoder(
|
249 |
+
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
|
250 |
+
)
|
251 |
|
252 |
|
253 |
# load TTS model
|
|
|
278 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
279 |
|
280 |
print(f"Using {model}...")
|
281 |
+
ema_model = load_model(
|
282 |
+
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
|
283 |
+
)
|
284 |
|
285 |
|
286 |
# inference process
|
|
|
336 |
sway_sampling_coef=sway_sampling_coef,
|
337 |
speed=speed,
|
338 |
fix_duration=fix_duration,
|
339 |
+
device=device,
|
340 |
)
|
341 |
generated_audio_segments.append(audio_segment)
|
342 |
|
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
|
|
149 |
dtype = (
|
150 |
torch.float16
|
151 |
if "cuda" in device
|
152 |
-
and torch.cuda.get_device_properties(device).major >=
|
153 |
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
154 |
else torch.float32
|
155 |
)
|
@@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
|
186 |
dtype = (
|
187 |
torch.float16
|
188 |
if "cuda" in device
|
189 |
-
and torch.cuda.get_device_properties(device).major >=
|
190 |
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
191 |
else torch.float32
|
192 |
)
|
@@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
|
|
289 |
# preprocess reference audio and text
|
290 |
|
291 |
|
292 |
-
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print
|
293 |
show_info("Converting audio...")
|
294 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
295 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
|
|
149 |
dtype = (
|
150 |
torch.float16
|
151 |
if "cuda" in device
|
152 |
+
and torch.cuda.get_device_properties(device).major >= 7
|
153 |
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
154 |
else torch.float32
|
155 |
)
|
|
|
186 |
dtype = (
|
187 |
torch.float16
|
188 |
if "cuda" in device
|
189 |
+
and torch.cuda.get_device_properties(device).major >= 7
|
190 |
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
191 |
else torch.float32
|
192 |
)
|
|
|
289 |
# preprocess reference audio and text
|
290 |
|
291 |
|
292 |
+
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
|
293 |
show_info("Converting audio...")
|
294 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
295 |
aseg = AudioSegment.from_file(ref_audio_orig)
|