# testspace/src/rp_handler.py
# Last commit af9920b by StormblessedKal: "support s3" (2.79 kB)
"""
rp_handler.py for runpod worker
rp_debugger:
- Utility that provides additional debugging information.
The handler must be called with --rp_debugger flag to enable it.
"""
import base64
import tempfile
from rp_schema import INPUT_VALIDATIONS
from runpod.serverless.utils import download_files_from_urls, rp_cleanup, rp_debugger
from runpod.serverless.utils.rp_validator import validate
import runpod
import predict
# Build the predictor once at import time so every job handled by this worker
# reuses the same instance instead of constructing one per request.
MODEL = predict.Predictor()
# NOTE(review): setup() presumably loads model weights/resources -- confirm in
# predict.py. It runs before the handler is registered with RunPod at the
# bottom of this file.
MODEL.setup()
# Accepted values of input["method_type"]. A tuple (not a set) so that an
# unhashable method_type falls through to the error response instead of
# raising TypeError.
_METHOD_TYPES = (
    "create_voice",
    "voice_clone",
    "voice_clone_with_emotions",
    "voice_clone_with_multi_lang",
)


def _create_voice(job_input):
    """Handle method_type == "create_voice" by calling MODEL.createvoice.

    Requires at least one of ``s3_url`` / ``audio_base64``. ``cut_audio``
    defaults to 0 and ``process_audio`` defaults to False when missing or None.
    """
    s3_url = job_input.get("s3_url")
    audio_base64 = job_input.get('audio_base64')
    if audio_base64 is None and s3_url is None:
        return {"error": "set audio_base64 or s3_url"}
    cut_audio = job_input.get('cut_audio')
    process_audio = job_input.get('process_audio')
    print(process_audio)  # logs the raw (pre-default) value
    if process_audio is None:
        process_audio = False
    if cut_audio is None:
        cut_audio = 0
    return MODEL.createvoice(s3_url, audio_base64, cut_audio, process_audio)


def _voice_clone(job_input, method_type):
    """Handle the voice_clone* method types by calling the matching predictor.

    ``s3_url`` and ``passage`` are forwarded as-is. ``process_audio`` defaults
    to False. ``output_extension`` defaults to "mp3" and must be "mp3" or "ogg".
    For "voice_clone" only, ``run_type`` is forwarded as a fifth argument when
    present.
    """
    s3_url = job_input.get('s3_url')
    passage = job_input.get('passage')
    process_audio = job_input.get('process_audio')
    print(process_audio)  # logs the raw (pre-default) value
    if process_audio is None:
        process_audio = False
    output_extension = job_input.get('output_extension')
    if output_extension is None:
        output_extension = "mp3"
    if output_extension not in ("mp3", "ogg"):
        return {"error": "only supports mp3 and ogg as output_extension"}
    print(output_extension)

    # Arguments shared by every voice_clone* predictor, in their positional order.
    common_args = (s3_url, passage, process_audio, output_extension)
    if method_type == 'voice_clone':
        run_type = job_input.get('run_type')
        if run_type is not None:
            return MODEL.predict(*common_args, run_type)
        # Omit run_type entirely so the predictor's own default applies.
        return MODEL.predict(*common_args)
    if method_type == 'voice_clone_with_emotions':
        return MODEL.predict_with_emotions(*common_args)
    # Only "voice_clone_with_multi_lang" remains after validation in the caller.
    return MODEL.predict_with_multi_lang(*common_args)


@rp_debugger.FunctionTimer
def run_voice_clone_job(job):
    """RunPod handler: validate job["input"] and dispatch on its method_type.

    ``job["input"]["method_type"]`` must be one of ``create_voice``,
    ``voice_clone``, ``voice_clone_with_emotions`` or
    ``voice_clone_with_multi_lang``.

    Returns:
        The value returned by the selected MODEL method, or a dict of the form
        ``{"error": <message>}`` when the request is invalid. This includes a
        missing or non-dict ``input``, which previously raised.
    """
    job_input = job.get('input')
    if not isinstance(job_input, dict):
        return {"error": "job input must be a JSON object"}
    method_type = job_input.get('method_type')
    print(method_type)
    if method_type not in _METHOD_TYPES:
        return {"error": "Please set method_type: available options, create_voice, voice_clone, voice_clone_with_emotions,voice_clone_with_multi_lang"}
    if method_type == "create_voice":
        return _create_voice(job_input)
    return _voice_clone(job_input, method_type)
# Register run_voice_clone_job as the RunPod serverless handler.
# NOTE(review): this call sits at module level with no `if __name__ == "__main__":`
# guard, so it runs whenever this module is imported -- confirm that is intended.
runpod.serverless.start({"handler": run_voice_clone_job})