ZiyueJiang committed on
Commit
f447f4e
·
1 Parent(s): d2c9151

code update for duration of ZeroGPU

Browse files
Files changed (2) hide show
  1. tts/gradio_api.py +3 -4
  2. tts/infer_cli.py +6 -0
tts/gradio_api.py CHANGED
@@ -20,9 +20,7 @@ import gradio as gr
20
  import traceback
21
  from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
22
 
23
- import spaces
24
 
25
- @spaces.GPU(duration=120)
26
  def model_worker(input_queue, output_queue, device_id):
27
  device = None
28
  if device_id is not None:
@@ -39,8 +37,9 @@ def model_worker(input_queue, output_queue, device_id):
39
  cut_wav(wav_path, max_len=28)
40
  with open(wav_path, 'rb') as file:
41
  file_content = file.read()
42
- resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
43
- wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
 
44
  output_queue.put(wav_bytes)
45
  except Exception as e:
46
  traceback.print_exc()
 
20
  import traceback
21
  from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
22
 
 
23
 
 
24
  def model_worker(input_queue, output_queue, device_id):
25
  device = None
26
  if device_id is not None:
 
37
  cut_wav(wav_path, max_len=28)
38
  with open(wav_path, 'rb') as file:
39
  file_content = file.read()
40
+ wav_bytes = infer_pipe.forward_zerogpu(file_content, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
41
+ # resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
42
+ # wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
43
  output_queue.put(wav_bytes)
44
  except Exception as e:
45
  traceback.print_exc()
tts/infer_cli.py CHANGED
@@ -252,6 +252,12 @@ class MegaTTS3DiTInfer():
252
  wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
253
  return to_wav_bytes(wav_pred, self.sr)
254
 
 
 
 
 
 
 
255
 
256
  if __name__ == '__main__':
257
  parser = argparse.ArgumentParser()
 
252
  wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
253
  return to_wav_bytes(wav_pred, self.sr)
254
 
255
+ @spaces.GPU(duration=120)
256
+ def forward_zerogpu(self, file_content, latent_file, inp_text, time_step, p_w, t_w):
257
+ resource_context = self.preprocess(file_content, latent_file)
258
+ wav_bytes = self.forward(resource_context, inp_text, time_step=time_step, p_w=p_w, t_w=t_w)
259
+ return wav_bytes
260
+
261
 
262
  if __name__ == '__main__':
263
  parser = argparse.ArgumentParser()