StormblessedKal commited on
Commit
e13b6d4
1 Parent(s): 5dfe293

add new parameter

Browse files
src/__pycache__/predict.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/predict.cpython-310.pyc and b/src/__pycache__/predict.cpython-310.pyc differ
 
src/__pycache__/rp_schema.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/rp_schema.cpython-310.pyc and b/src/__pycache__/rp_schema.cpython-310.pyc differ
 
src/predict.py CHANGED
@@ -204,7 +204,7 @@ class Predictor:
204
  return {"url": file_url}
205
 
206
 
207
- def predict(self,s3_url,passage,process_audio):
208
  output_dir = 'processed'
209
  gen_id = str(uuid.uuid4())
210
  os.makedirs(output_dir,exist_ok=True)
@@ -222,41 +222,38 @@ class Predictor:
222
  local_file_path = os.path.join(raw_dir,s3_key)
223
  self.download_file_from_s3(self.s3_client,bucket_name,s3_key,local_file_path)
224
  #voice_clone with styletts2
225
- model,sampler = self.model,self.sampler
226
- result = self.process_audio_file(local_file_path,passage,model,sampler)
227
- final_output = os.path.join(results_dir,f"{gen_id}-voice-clone-1.wav")
228
-
229
- sf.write(final_output,result,24000)
230
- if process_audio:
231
- (new_sr, wav1) = self._fn(final_output,"Midpoint",32,0.5)
232
- sf.write(final_output,wav1,new_sr)
233
-
234
- base_speaker_tts,tone_color_converter = self.base_speaker_tts,self.tone_color_converter
235
- reference_speaker = local_file_path
236
- target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir=openvoice_dir, vad=False)
237
- src_path = os.path.join(results_dir,f"{gen_id}-tmp.wav")
238
- openvoice_output = os.path.join(results_dir,f"{gen_id}-voice-clone-2.wav")
239
- base_speaker_tts.tts(passage,src_path,speaker='default',language='English',speed=1.0)
240
-
241
- source_se = torch.load(f'{self.ckpt_base}/en_default_se.pth').to(self.device)
242
- tone_color_converter.convert(audio_src_path=src_path,src_se=source_se,tgt_se=target_se,output_path=openvoice_output,message='')
243
- if process_audio:
244
- (new_sr, wav1) = self._fn(openvoice_output,"Midpoint",32,0.5)
245
- sf.write(openvoice_output,wav1,new_sr)
246
-
247
-
248
- mp3_final_output_1 = str(final_output).replace('wav','mp3')
249
- mp3_final_output_2 = str(openvoice_output).replace('wav','mp3')
250
- self.convert_wav_to_mp3(final_output,mp3_final_output_1)
251
- self.convert_wav_to_mp3(openvoice_output,mp3_final_output_2)
252
- print(mp3_final_output_1)
253
- print(mp3_final_output_2)
254
 
255
- self.upload_file_to_s3(mp3_final_output_1,'demovidelyusergenerations',f"{gen_id}-voice-clone-1.mp3")
256
- self.upload_file_to_s3(mp3_final_output_2,'demovidelyusergenerations',f"{gen_id}-voice-clone-2.mp3")
257
  shutil.rmtree(os.path.join(output_dir,gen_id))
258
- return {"voice_clone_1":f"https://demovidelyusergenerations.s3.amazonaws.com/{gen_id}-voice-clone-1.mp3",
259
- "voice_clone_2":f"https://demovidelyusergenerations.s3.amazonaws.com/{gen_id}-voice-clone-2.mp3"
260
  }
261
 
262
 
 
204
  return {"url": file_url}
205
 
206
 
207
+ def predict(self,s3_url,passage,process_audio,run_type='styletts2'):
208
  output_dir = 'processed'
209
  gen_id = str(uuid.uuid4())
210
  os.makedirs(output_dir,exist_ok=True)
 
222
  local_file_path = os.path.join(raw_dir,s3_key)
223
  self.download_file_from_s3(self.s3_client,bucket_name,s3_key,local_file_path)
224
  #voice_clone with styletts2
225
+ if run_type == 'styletts2':
226
+ model,sampler = self.model,self.sampler
227
+ result = self.process_audio_file(local_file_path,passage,model,sampler)
228
+ final_output = os.path.join(results_dir,f"{gen_id}-voice-clone-1.wav")
229
+
230
+ sf.write(final_output,result,24000)
231
+ if process_audio:
232
+ (new_sr, wav1) = self._fn(final_output,"Midpoint",32,0.5)
233
+ sf.write(final_output,wav1,new_sr)
234
+ mp3_final_output = str(final_output).replace('wav','mp3')
235
+ self.convert_wav_to_mp3(final_output,mp3_final_output)
236
+
237
+ if run_type == 'openvoice':
238
+ s_ref = self.compute_style(local_file_path, self.model)
239
+ base_speaker_tts,tone_color_converter = self.base_speaker_tts,self.tone_color_converter
240
+ reference_speaker = local_file_path
241
+ target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir=openvoice_dir, vad=False)
242
+ src_path = os.path.join(results_dir,f"{gen_id}-tmp.wav")
243
+ openvoice_output = os.path.join(results_dir,f"{gen_id}-voice-clone-2.wav")
244
+ base_speaker_tts.tts(passage,src_path,speaker='default',language='English',speed=1.0)
245
+
246
+ source_se = torch.load(f'{self.ckpt_base}/en_default_se.pth').to(self.device)
247
+ tone_color_converter.convert(audio_src_path=src_path,src_se=source_se,tgt_se=target_se,output_path=openvoice_output,message='')
248
+ if process_audio:
249
+ (new_sr, wav1) = self._fn(openvoice_output,"Midpoint",32,0.5)
250
+ sf.write(openvoice_output,wav1,new_sr)
251
+ mp3_final_output = str(openvoice_output).replace('wav','mp3')
252
+ self.convert_wav_to_mp3(openvoice_output,mp3_final_output)
 
253
 
254
+ self.upload_file_to_s3(mp3_final_output,'demovidelyusergenerations',f"{gen_id}-voice-clone.mp3")
 
255
  shutil.rmtree(os.path.join(output_dir,gen_id))
256
+ return {"voice_clone":f"https://demovidelyusergenerations.s3.amazonaws.com/{gen_id}-voice-clone.mp3"
 
257
  }
258
 
259
 
src/rp_handler.py CHANGED
@@ -50,7 +50,11 @@ def run_voice_clone_job(job):
50
  process_audio = False
51
 
52
  if method_type == 'voice_clone':
53
- result = MODEL.predict(s3_url,passage,process_audio)
 
 
 
 
54
  if method_type == 'voice_clone_with_emotions':
55
  result = MODEL.predict_with_emotions(s3_url,passage,process_audio)
56
  if method_type == 'voice_clone_with_multi_lang':
 
50
  process_audio = False
51
 
52
  if method_type == 'voice_clone':
53
+ run_type = job_input.get('run_type')
54
+ if run_type is not None:
55
+ result = MODEL.predict(s3_url,passage,process_audio,run_type)
56
+ else:
57
+ result = MODEL.predict(s3_url,passage,process_audio)
58
  if method_type == 'voice_clone_with_emotions':
59
  result = MODEL.predict_with_emotions(s3_url,passage,process_audio)
60
  if method_type == 'voice_clone_with_multi_lang':
src/rp_schema.py CHANGED
@@ -28,6 +28,10 @@ INPUT_VALIDATIONS = {
28
  'type': bool,
29
  'required': False,
30
  'default': False
 
 
 
 
 
31
  }
32
-
33
  }
 
28
  'type': bool,
29
  'required': False,
30
  'default': False
31
+ },
32
+ 'run_type': {
33
+ 'type': str,
34
+ 'required': False,
35
+ 'default': False
36
  }
 
37
  }