# Source: testspace / src / rp_handler.py (2.03 kB)
# Author: StormblessedKal
# Commit: 508fd98 - "try two methods"
"""
RunPod serverless worker entry point (rp_handler.py).

rp_debugger:
    Utility that provides additional debugging information.
    The handler must be started with the --rp_debugger flag to enable it.
"""
import base64
import tempfile
from rp_schema import INPUT_VALIDATIONS
from runpod.serverless.utils import download_files_from_urls, rp_cleanup, rp_debugger
from runpod.serverless.utils.rp_validator import validate
import runpod
import predict
# Build a single module-level predictor when this file is loaded; the job
# handler below calls its methods for every request.
MODEL = predict.Predictor()
# Run the predictor's own initialisation before any job is accepted
# (presumably loads model weights -- confirm against predict.Predictor.setup).
MODEL.setup()
@rp_debugger.FunctionTimer
def run_voice_clone_job(job):
    """Route a RunPod job to the matching method on the module-level MODEL.

    Args:
        job: Job payload. ``job['input']`` is read as a mapping holding
            ``method_type`` plus the fields that method uses:

            * ``create_voice``: ``audio_base64`` (required), ``cut_audio``
              (falls back to 0) and ``process_audio`` (falls back to False).
            * ``voice_clone`` / ``voice_clone_with_emotions``: ``s3_url``,
              ``passage`` and ``process_audio`` (falls back to False).

    Returns:
        The value returned by the selected MODEL method, or a dict with an
        ``"error"`` key when ``method_type`` is missing/unknown, when
        ``audio_base64`` is missing for ``create_voice``, or when the
        requested method has no implementation wired up in this handler.
    """
    job_input = job['input']
    method_type = job_input.get('method_type')
    print(method_type)

    if method_type not in ["create_voice", "voice_clone", "voice_clone_with_emotions", "voice_clone_with_multi_lang"]:
        return {"error":"Please set method_type: available options, create_voice, voice_clone, voice_clone_with_emotions,voice_clone_with_multi_lang"}

    if method_type == "create_voice":
        audio_base64 = job_input.get('audio_base64')
        if audio_base64 is None:
            return {"error":"Needs audio file as base64"}

        cut_audio = job_input.get('cut_audio')
        process_audio = job_input.get('process_audio')
        print(process_audio)
        # Only an absent/None value triggers the fallback; explicit falsy
        # values supplied by the caller are passed through untouched.
        if process_audio is None:
            process_audio = False
        if cut_audio is None:
            cut_audio = 0
        return MODEL.createvoice(audio_base64, cut_audio, process_audio)

    s3_url = job_input.get('s3_url')
    passage = job_input.get('passage')
    process_audio = job_input.get('process_audio')
    print(process_audio)
    if process_audio is None:
        process_audio = False

    if method_type == 'voice_clone':
        return MODEL.predict(s3_url, passage, process_audio)
    if method_type == 'voice_clone_with_emotions':
        return MODEL.predict_with_emotions(s3_url, passage, process_audio)

    # 'voice_clone_with_multi_lang' passes validation but has no branch here.
    # Previously this fell through to `return result` with `result` unbound
    # and crashed with UnboundLocalError; report it as a structured error.
    return {"error": f"method_type '{method_type}' is not implemented by this handler"}
# Hand control to the RunPod serverless runtime, registering
# run_voice_clone_job as the job handler. This runs at module top level,
# so it executes as soon as the file is loaded.
runpod.serverless.start({"handler": run_voice_clone_job})