{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Voice-conversion handler: local test harness\n",
    "\n",
    "Runs the job handler end to end inside the notebook: download a song, split it into `vocals.mp3` and `accompaniment.mp3` with `deezer-split.py`, re-voice the vocals with a voice model via `generate_wav`, overlay the new vocals on the accompaniment, and upload the results to the `voice-gen-audios` S3 bucket.\n",
    "\n",
    "**Configuration** is read from environment variables, which must be set in the kernel's environment:\n",
    "\n",
    "- `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`: credentials used for the S3 uploads.\n",
    "- `API_KEY`: carried over from an earlier revision; not used by the cells below.\n",
    "\n",
    "> **Security:** earlier revisions of this notebook hard-coded an API key and AWS credentials. Treat them as leaked: rotate them and purge them from version-control history."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "import time\n",
    "import uuid\n",
    "\n",
    "import boto3\n",
    "import requests\n",
    "import runpod  # noqa: F401 -- not used by the cells below\n",
    "from pydub import AudioSegment\n",
    "\n",
    "from voice_generation import generate_wav"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Secrets are read from the environment -- never hard-code them in the notebook.\n",
    "API_KEY = os.environ.get(\"API_KEY\", \"\")\n",
    "AWS_ACCESS_KEY_ID = os.environ.get(\"AWS_ACCESS_KEY_ID\", \"\")\n",
    "AWS_SECRET_ACCESS_KEY = os.environ.get(\"AWS_SECRET_ACCESS_KEY\", \"\")\n",
    "\n",
    "S3_BUCKET = \"voice-gen-audios\"\n",
    "S3_REGION = \"eu-north-1\"\n",
    "\n",
    "# (connect, read) timeouts in seconds for HTTP downloads; without a timeout\n",
    "# requests.get can block forever on an unresponsive server.\n",
    "REQUEST_TIMEOUT = (10, 300)\n",
    "\n",
    "# Where a model fetched from `voice_model_url` is saved; removed after every job.\n",
    "DOWNLOADED_MODEL_PATH = \"voice_model.pth\"\n",
    "\n",
    "# Arguments passed straight through to generate_wav.\n",
    "GENERATION_METHOD = \"pm\"\n",
    "INDEX_RATE = 0.6\n",
    "\n",
    "# Pre-loaded voice models, selectable with the `voice_model_id` input.\n",
    "models = {\n",
    "    'kanye': 'weights/kanye.pth',\n",
    "    'rose-bp': 'weights/rose-bp.pth',\n",
    "    'jungkook': 'weights/jungkook.pth',\n",
    "    'iu': 'weights/iu.pth',\n",
    "    'drake': 'weights/drake.pth',\n",
    "    'ariana-grande': 'weights/ariana-grande.pth',\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pipeline helpers\n",
    "\n",
    "Each helper reports failure as a dict with an `error` key instead of raising, so the handler can pass the message back to the caller."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_audio():\n",
    "    \"\"\"Run ``deezer-split.py`` in the ``env`` virtualenv to split the downloaded song.\n",
    "\n",
    "    The script is expected to read ``song.mp3`` and write ``vocals.mp3`` and\n",
    "    ``accompaniment.mp3``; the handler verifies those files exist afterwards.\n",
    "\n",
    "    Returns:\n",
    "        int: the script's exit code (0 on success).\n",
    "    \"\"\"\n",
    "    return subprocess.call([\"env/bin/python\", \"deezer-split.py\"])\n",
    "\n",
    "\n",
    "def _pad_to(segment, length_ms):\n",
    "    \"\"\"Return ``segment`` extended with trailing silence to ``length_ms`` milliseconds.\"\"\"\n",
    "    if len(segment) >= length_ms:\n",
    "        return segment\n",
    "    # Match the segment's frame rate so pydub does not resample on concatenation.\n",
    "    return segment + AudioSegment.silent(duration=length_ms - len(segment), frame_rate=segment.frame_rate)\n",
    "\n",
    "\n",
    "def combine_audio(voice_path, instrumental_path, output_path=\"combined.mp3\"):\n",
    "    \"\"\"Overlay a vocal track on an instrumental track and export the mix as MP3.\n",
    "\n",
    "    Both tracks are padded to the longer one's length first, because\n",
    "    ``AudioSegment.overlay`` truncates the result to the base segment's length.\n",
    "\n",
    "    Args:\n",
    "        voice_path: MP3 file with the vocals.\n",
    "        instrumental_path: MP3 file with the accompaniment.\n",
    "        output_path: destination of the mixed MP3.\n",
    "    \"\"\"\n",
    "    instrumental = AudioSegment.from_file(instrumental_path, format=\"mp3\")\n",
    "    voice = AudioSegment.from_file(voice_path, format=\"mp3\")\n",
    "\n",
    "    length_ms = max(len(instrumental), len(voice))\n",
    "    mixed = _pad_to(instrumental, length_ms).overlay(_pad_to(voice, length_ms))\n",
    "    mixed.export(output_path, format=\"mp3\")\n",
    "\n",
    "\n",
    "def upload_file_to_s3(local_file_path, s3_file_path):\n",
    "    \"\"\"Upload ``local_file_path`` to ``S3_BUCKET`` under the key ``s3_file_path``.\n",
    "\n",
    "    Returns:\n",
    "        ``{\"url\": ...}`` with the object's virtual-hosted-style URL on success\n",
    "        (readable only if the bucket policy allows it), or ``{\"error\": ...}``.\n",
    "    \"\"\"\n",
    "    s3 = boto3.client(\n",
    "        \"s3\",\n",
    "        region_name=S3_REGION,\n",
    "        aws_access_key_id=AWS_ACCESS_KEY_ID,\n",
    "        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,\n",
    "    )\n",
    "    try:\n",
    "        s3.upload_file(local_file_path, S3_BUCKET, s3_file_path)\n",
    "    except boto3.exceptions.S3UploadFailedError as e:\n",
    "        return {\"error\": f\"failed to upload file {local_file_path} to s3 as {s3_file_path}: {e}\"}\n",
    "    return {\"url\": f\"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{s3_file_path}\"}\n",
    "\n",
    "\n",
    "def clean_up_files(remove_voice_model=False):\n",
    "    \"\"\"Delete the per-job working files.\n",
    "\n",
    "    Files that do not exist are skipped (a failed job may stop before creating\n",
    "    them), and a failure on one file does not stop removal of the others.\n",
    "\n",
    "    Args:\n",
    "        remove_voice_model: also delete the downloaded model at ``DOWNLOADED_MODEL_PATH``.\n",
    "\n",
    "    Returns:\n",
    "        ``{\"success\": ...}``, or ``{\"error\": ...}`` naming the files that could not be removed.\n",
    "    \"\"\"\n",
    "    files = [\n",
    "        \"song.mp3\",\n",
    "        \"accompaniment.mp3\",\n",
    "        \"vocals.mp3\",\n",
    "        \"output_vocal.mp3\",\n",
    "        \"combined.mp3\",\n",
    "    ]\n",
    "    if remove_voice_model:\n",
    "        files.append(DOWNLOADED_MODEL_PATH)\n",
    "\n",
    "    failed = []\n",
    "    for file in files:\n",
    "        try:\n",
    "            os.remove(file)\n",
    "        except FileNotFoundError:\n",
    "            continue\n",
    "        except OSError as e:\n",
    "            failed.append(f\"{file} ({e})\")\n",
    "    if failed:\n",
    "        return {\"error\": f\"failed to remove files: {', '.join(failed)}\"}\n",
    "    return {\"success\": \"files removed successfully\"}\n",
    "\n",
    "\n",
    "def get_voice_model(event):\n",
    "    \"\"\"Resolve the voice model requested by a job.\n",
    "\n",
    "    ``voice_model_id`` (a key of ``models``) takes precedence over\n",
    "    ``voice_model_url``; a URL is downloaded to ``DOWNLOADED_MODEL_PATH``.\n",
    "\n",
    "    Returns:\n",
    "        ``{\"model_path\": ...}`` on success, otherwise ``{\"error\": ...}``.\n",
    "    \"\"\"\n",
    "    voice_model_id = event[\"input\"].get(\"voice_model_id\", \"\")\n",
    "    voice_model_url = event[\"input\"].get(\"voice_model_url\", \"\")\n",
    "\n",
    "    if not voice_model_url and not voice_model_id:\n",
    "        return {\"error\": \"voice_model_url or voice_model_id is required\"}\n",
    "\n",
    "    if voice_model_id:\n",
    "        if voice_model_id not in models:\n",
    "            return {\"error\": \"model not found in pre-loaded models\"}\n",
    "        return {\"model_path\": models[voice_model_id]}\n",
    "\n",
    "    # NOTE(review): .pth checkpoints are usually loaded with pickle-based torch.load\n",
    "    # inside generate_wav -- confirm. If so, a model from an untrusted URL can run\n",
    "    # arbitrary code on this worker; only accept URLs from trusted sources.\n",
    "    print(\"downloading voice_model\")\n",
    "    try:\n",
    "        response = requests.get(voice_model_url, timeout=REQUEST_TIMEOUT)\n",
    "    except requests.RequestException as e:\n",
    "        return {\"error\": f\"failed to download voice_model, error: {e}\"}\n",
    "    if response.status_code != 200:\n",
    "        return {\"error\": f\"failed to download voice_model, error: {response.text}\"}\n",
    "\n",
    "    with open(DOWNLOADED_MODEL_PATH, \"wb\") as f:\n",
    "        f.write(response.content)\n",
    "    return {\"model_path\": DOWNLOADED_MODEL_PATH}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Handler\n",
    "\n",
    "Errors are returned as a dict with an `error` key. The working files are removed after every job, successful or not, so a later job on the same worker never picks up stale stems or mixes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _download_song(song_url):\n",
    "    \"\"\"Download ``song_url`` to ``song.mp3``.\n",
    "\n",
    "    Returns:\n",
    "        ``None`` on success, otherwise an ``{\"error\": ...}`` dict.\n",
    "    \"\"\"\n",
    "    # NOTE(review): song_url comes from the job input and is fetched as-is;\n",
    "    # restrict allowed hosts if untrusted callers can submit jobs.\n",
    "    try:\n",
    "        response = requests.get(song_url, timeout=REQUEST_TIMEOUT)\n",
    "    except requests.RequestException as e:\n",
    "        return {\"error\": f\"failed to download song_file: {e}\"}\n",
    "    if response.status_code != 200:\n",
    "        return {\"error\": \"failed to download song_file\"}\n",
    "\n",
    "    with open(\"song.mp3\", \"wb\") as f:\n",
    "        f.write(response.content)\n",
    "    return None\n",
    "\n",
    "\n",
    "def _process_song(event, file_id, user_id, model_path):\n",
    "    \"\"\"Download, split, re-voice, mix and upload one song; return the handler result.\"\"\"\n",
    "    song_url = event[\"input\"].get(\"song_url\", \"\")\n",
    "    if not song_url:\n",
    "        return {\"error\": \"song_url is required\"}\n",
    "\n",
    "    download_error = _download_song(song_url)\n",
    "    if download_error:\n",
    "        return download_error\n",
    "\n",
    "    splitting_start = time.perf_counter()  # remove after testing\n",
    "    split_exit_code = split_audio()\n",
    "    time_taken_splitting = time.perf_counter() - splitting_start  # remove after testing\n",
    "    print(f\"splitting took {time_taken_splitting} seconds\")  # remove after testing\n",
    "\n",
    "    if split_exit_code != 0:\n",
    "        print(f\"deezer-split.py exited with code {split_exit_code}\")\n",
    "    if not os.path.exists(\"accompaniment.mp3\") or not os.path.exists(\"vocals.mp3\"):\n",
    "        return {\"error\": \"failed to split song\"}\n",
    "\n",
    "    # Both stems are uploaded before either result is checked; their URLs are\n",
    "    # not currently returned to the caller.\n",
    "    song_instruments = upload_file_to_s3(\"accompaniment.mp3\", f\"{file_id}-split-accompaniment.mp3\")\n",
    "    song_vocals = upload_file_to_s3(\"vocals.mp3\", f\"{file_id}-split-vocals.mp3\")\n",
    "    for upload in (song_instruments, song_vocals):\n",
    "        if \"error\" in upload:\n",
    "            return upload\n",
    "\n",
    "    generation_start = time.perf_counter()  # remove after testing\n",
    "    generation = generate_wav(\n",
    "        audio_file=\"vocals.mp3\",\n",
    "        method=GENERATION_METHOD,\n",
    "        index_rate=INDEX_RATE,\n",
    "        output_file=\"output_vocal.mp3\",\n",
    "        model_path=model_path,\n",
    "    )\n",
    "    time_taken_generation = time.perf_counter() - generation_start  # remove after testing\n",
    "    print(f\"generation took {time_taken_generation} seconds\")  # remove after testing\n",
    "\n",
    "    # NOTE(review): assumes generate_wav returns a dict carrying an \"error\" key on failure.\n",
    "    if \"error\" in generation:\n",
    "        return {\"error\": generation[\"error\"]}\n",
    "\n",
    "    print(\"before combining\")\n",
    "    combine_audio(\"output_vocal.mp3\", \"accompaniment.mp3\")\n",
    "    print(\"after combining\")\n",
    "\n",
    "    if not os.path.exists(\"combined.mp3\"):\n",
    "        return {\"error\": \"failed to combine audio\"}\n",
    "\n",
    "    combined = upload_file_to_s3(\"combined.mp3\", f\"{file_id}.mp3\")\n",
    "    # The \"voical\" spelling is kept so uploaded object keys stay unchanged.\n",
    "    output_voice = upload_file_to_s3(\"output_vocal.mp3\", f\"{file_id}-generated-voical.mp3\")\n",
    "    for upload in (combined, output_voice):\n",
    "        if \"error\" in upload:\n",
    "            return upload\n",
    "\n",
    "    return {\n",
    "        \"combined_url\": combined[\"url\"],\n",
    "        \"output_voice_url\": output_voice[\"url\"],\n",
    "        \"user_id\": user_id,\n",
    "        \"time_taken_splitting\": time_taken_splitting,  # remove after testing\n",
    "        \"time_taken_generation\": time_taken_generation,  # remove after testing\n",
    "    }\n",
    "\n",
    "\n",
    "def handler(event):\n",
    "    \"\"\"Job handler: re-voice the song at ``song_url`` with the requested voice model.\n",
    "\n",
    "    ``event[\"input\"]`` keys: ``song_url`` (required), ``voice_model_id`` or\n",
    "    ``voice_model_url`` (one is required; the id wins if both are given), and\n",
    "    optional ``user_id``.\n",
    "\n",
    "    Returns:\n",
    "        On success, ``combined_url``, ``output_voice_url``, ``user_id`` and the\n",
    "        stage timings; on failure, ``{\"error\": message}``.\n",
    "    \"\"\"\n",
    "    print(event)\n",
    "    file_id = str(uuid.uuid4())\n",
    "    user_id = event[\"input\"].get(\"user_id\", \"not provided\")\n",
    "\n",
    "    if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:\n",
    "        return {\"error\": \"AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are missing from environment variables\"}\n",
    "\n",
    "    voice_model = {}\n",
    "    try:\n",
    "        voice_model = get_voice_model(event)\n",
    "        if \"error\" in voice_model:\n",
    "            return voice_model\n",
    "        return _process_song(event, file_id, user_id, voice_model[\"model_path\"])\n",
    "    finally:\n",
    "        # Clean up after every job, including failed ones, so a later job on the\n",
    "        # same worker cannot pass the file-existence checks with stale files.\n",
    "        # A cleanup failure is logged rather than returned, so it does not\n",
    "        # discard the URLs of results that were already uploaded.\n",
    "        cleanup_result = clean_up_files(voice_model.get(\"model_path\") == DOWNLOADED_MODEL_PATH)\n",
    "        if \"error\" in cleanup_result:\n",
    "            print(f\"cleanup warning: {cleanup_result['error']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example run\n",
    "\n",
    "Uses a pre-loaded model. To test a downloaded model instead, replace `voice_model_id` with `voice_model_url`, e.g. `https://rvc-models.s3.amazonaws.com/lilbaby.pth`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = handler({\n",
    "    \"input\": {\n",
    "        \"song_url\": \"https://voice-gen-audios.s3.eu-north-1.amazonaws.com/combined_trimmed_original.mp3\",\n",
    "        \"user_id\": \"test_user\",\n",
    "        \"voice_model_id\": \"kanye\",\n",
    "    }\n",
    "})\n",
    "result"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}