# Spaces: Sleeping (status text captured from the hosting page; not part of the program)
import os
import shutil
import subprocess
from typing import Tuple, Dict, List

import gradio as gr
import torch
import demucs.api
import spaces
from pydub import AudioSegment
# check if cuda is available
device: str = "cuda" if torch.cuda.is_available() else "cpu"

# check if sox is installed and install it if necessary.
# shutil.which only looks the executable up on PATH, so the check itself
# cannot fail with a non-zero exit status the way running "sox --version" could.
if shutil.which("sox") is None:
    print("sox is not installed. trying to install it now...")
    try:
        subprocess.run(["apt-get", "update"], check=True)
        subprocess.run(["apt-get", "install", "-y", "sox"], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: apt-get ran but returned a non-zero status.
        # OSError (e.g. FileNotFoundError): apt-get itself could not be started.
        print(f"error installing sox: {e}")
        print("please install sox manually or try adding the following repository to your sources list:")
        print("deb http://deb.debian.org/debian stretch main contrib non-free")
        # raise SystemExit directly instead of relying on the site-provided exit() helper
        raise SystemExit(1)
    else:
        print("sox has been installed.")
# define the inference function | |
def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, str]:
    """
    performs inference using demucs and mixes the selected stems.

    args:
        audio_file: the audio file to separate.
        model_name: the name of the demucs model to use.
        vocals: whether to include vocals in the mix.
        drums: whether to include drums in the mix.
        bass: whether to include bass in the mix.
        other: whether to include other instruments in the mix.
        mp3: whether to save the output as mp3.
        mp3_bitrate: the bitrate of the output mp3 file.

    returns:
        a tuple containing the path to the mixed audio file and the separation log.

    raises:
        gr.Error: if no audio file was provided or no stem is selected.
    """
    if not audio_file:
        raise gr.Error("please provide an audio file to separate.")

    # validate the selection before running the (expensive) separation
    selected_names: List[str] = [
        name
        for name, include in zip(["vocals", "drums", "bass", "other"], [vocals, drums, bass, other])
        if include
    ]
    if not selected_names:
        raise gr.Error("please select at least one stem to mix.")

    log_lines: List[str] = [f"separating {os.path.basename(audio_file)} with model {model_name}"]

    def record_progress(info: dict) -> None:
        # demucs passes a dict describing the chunk being processed; .get() is
        # used so an unexpected/missing key never aborts the separation.
        log_lines.append(
            f"model {info.get('model_idx_in_bag')} shift {info.get('shift_idx')} "
            f"offset {info.get('segment_offset')}/{info.get('audio_length')}: {info.get('state')}"
        )

    # progress and callback are options of the Separator itself;
    # separate_audio_file only takes the file to process.
    separator: demucs.api.Separator = demucs.api.Separator(
        model=model_name, progress=True, callback=record_progress
    )
    _origin, separated = separator.separate_audio_file(audio_file)

    # write every separated stem next to each other in a per-track directory
    output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
    os.makedirs(output_dir, exist_ok=True)  # create output directory if it doesn't exist
    stems: Dict[str, str] = {}
    for stem, source in separated.items():
        stem_path: str = os.path.join(output_dir, f"{stem}.wav")
        demucs.api.save_audio(source, stem_path, samplerate=separator.samplerate)
        stems[stem] = stem_path
        log_lines.append(f"saved {stem} stem to {stem_path}")

    # mix the selected stems: overlay plays them simultaneously, whereas "+"
    # on AudioSegment would append them one after another.
    selected_stems: List[str] = [stems[name] for name in selected_names]
    mixed_audio: AudioSegment = AudioSegment.from_wav(selected_stems[0])
    for stem_path in selected_stems[1:]:
        mixed_audio = mixed_audio.overlay(AudioSegment.from_wav(stem_path))

    # always export through mixed_audio so the stem files stay in place and
    # the mp3 branch below works for a single selected stem as well
    output_file: str = os.path.join(output_dir, "mixed.wav")
    mixed_audio.export(output_file, format="wav")

    # automatically convert to mp3 if requested
    if mp3:
        mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
        # int() guards against the slider delivering a float such as 128.0
        mixed_audio.export(mp3_output_file, format="mp3", bitrate=f"{int(mp3_bitrate)}k")
        output_file = mp3_output_file  # update output_file to the mp3 file

    log_lines.append(f"mixed {', '.join(selected_names)} into {output_file}")
    return output_file, "\n".join(log_lines)
# define the gradio interface | |
# the interface is built with gr.Blocks because event listeners (the mp3
# checkbox toggling the bitrate slider) can only be attached inside a Blocks
# context, and the components must be referenced directly rather than through
# an attribute of the finished interface.
with gr.Blocks(title="demucs music source separation and mixing") as iface:
    gr.Markdown("# demucs music source separation and mixing")
    gr.Markdown("separate vocals, drums, bass, and other instruments from your music using demucs and mix the selected stems.")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath")
            model_dropdown = gr.Dropdown(
                ["htdemucs", "htdemucs_ft", "htdemucs_6s", "hdemucs_mmi", "mdx", "mdx_extra", "mdx_q", "mdx_extra_q"],
                label="model name",
                value="htdemucs_ft",  # set default value
            )
            vocals_checkbox = gr.Checkbox(label="vocals", value=True)
            drums_checkbox = gr.Checkbox(label="drums", value=True)
            bass_checkbox = gr.Checkbox(label="bass", value=True)
            other_checkbox = gr.Checkbox(label="other", value=True)
            mp3_checkbox = gr.Checkbox(label="save as mp3", value=False)  # set default value to false
            bitrate_slider = gr.Slider(128, 320, step=32, label="mp3 bitrate", visible=False)  # hidden until mp3 is checked
            submit_button = gr.Button("submit")
        with gr.Column():
            mixed_output = gr.Audio(type="filepath")
            log_output = gr.Textbox(label="separation log", lines=10)

    # make mp3 bitrate slider visible only when "save as mp3" is checked
    mp3_checkbox.change(
        fn=lambda mp3: gr.update(visible=mp3),
        inputs=mp3_checkbox,
        outputs=bitrate_slider,
    )

    # the input order must match the parameter order of inference()
    submit_button.click(
        fn=inference,
        inputs=[
            audio_input,
            model_dropdown,
            vocals_checkbox,
            drums_checkbox,
            bass_checkbox,
            other_checkbox,
            mp3_checkbox,
            bitrate_slider,
        ],
        outputs=[mixed_output, log_output],
    )

# launch the gradio interface
iface.launch()