mrfakename committed on
Commit
5f7ec69
·
verified ·
1 Parent(s): 43bc5dc

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions through the GitHub repo rather than this Space.

src/f5_tts/api.py CHANGED
@@ -119,7 +119,7 @@ class F5TTS:
119
  seed_everything(seed)
120
  self.seed = seed
121
 
122
- ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
123
 
124
  wav, sr, spec = infer_process(
125
  ref_file,
 
119
  seed_everything(seed)
120
  self.seed = seed
121
 
122
+ ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
123
 
124
  wav, sr, spec = infer_process(
125
  ref_file,
src/f5_tts/infer/infer_cli.py CHANGED
@@ -162,6 +162,11 @@ parser.add_argument(
162
  type=float,
163
  help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
164
  )
 
 
 
 
 
165
  args = parser.parse_args()
166
 
167
 
@@ -202,6 +207,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
202
  sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
203
  speed = args.speed or config.get("speed", speed)
204
  fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
 
205
 
206
 
207
  # patches for pip pkg user
@@ -239,7 +245,9 @@ if vocoder_name == "vocos":
239
  elif vocoder_name == "bigvgan":
240
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
241
 
242
- vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
 
 
243
 
244
 
245
  # load TTS model
@@ -270,7 +278,9 @@ if not ckpt_file:
270
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
271
 
272
  print(f"Using {model}...")
273
- ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
 
 
274
 
275
 
276
  # inference process
@@ -326,6 +336,7 @@ def main():
326
  sway_sampling_coef=sway_sampling_coef,
327
  speed=speed,
328
  fix_duration=fix_duration,
 
329
  )
330
  generated_audio_segments.append(audio_segment)
331
 
 
162
  type=float,
163
  help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
164
  )
165
+ parser.add_argument(
166
+ "--device",
167
+ type=str,
168
+ help="Specify the device to run on",
169
+ )
170
  args = parser.parse_args()
171
 
172
 
 
207
  sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
208
  speed = args.speed or config.get("speed", speed)
209
  fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
210
+ device = args.device
211
 
212
 
213
  # patches for pip pkg user
 
245
  elif vocoder_name == "bigvgan":
246
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
247
 
248
+ vocoder = load_vocoder(
249
+ vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
250
+ )
251
 
252
 
253
  # load TTS model
 
278
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
279
 
280
  print(f"Using {model}...")
281
+ ema_model = load_model(
282
+ model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
283
+ )
284
 
285
 
286
  # inference process
 
336
  sway_sampling_coef=sway_sampling_coef,
337
  speed=speed,
338
  fix_duration=fix_duration,
339
+ device=device,
340
  )
341
  generated_audio_segments.append(audio_segment)
342
 
src/f5_tts/infer/utils_infer.py CHANGED
@@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
149
  dtype = (
150
  torch.float16
151
  if "cuda" in device
152
- and torch.cuda.get_device_properties(device).major >= 6
153
  and not torch.cuda.get_device_name().endswith("[ZLUDA]")
154
  else torch.float32
155
  )
@@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
186
  dtype = (
187
  torch.float16
188
  if "cuda" in device
189
- and torch.cuda.get_device_properties(device).major >= 6
190
  and not torch.cuda.get_device_name().endswith("[ZLUDA]")
191
  else torch.float32
192
  )
@@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
289
  # preprocess reference audio and text
290
 
291
 
292
- def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
293
  show_info("Converting audio...")
294
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
295
  aseg = AudioSegment.from_file(ref_audio_orig)
 
149
  dtype = (
150
  torch.float16
151
  if "cuda" in device
152
+ and torch.cuda.get_device_properties(device).major >= 7
153
  and not torch.cuda.get_device_name().endswith("[ZLUDA]")
154
  else torch.float32
155
  )
 
186
  dtype = (
187
  torch.float16
188
  if "cuda" in device
189
+ and torch.cuda.get_device_properties(device).major >= 7
190
  and not torch.cuda.get_device_name().endswith("[ZLUDA]")
191
  else torch.float32
192
  )
 
289
  # preprocess reference audio and text
290
 
291
 
292
+ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
293
  show_info("Converting audio...")
294
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
295
  aseg = AudioSegment.from_file(ref_audio_orig)