# Snapshot header (scraped from a GitHub diff page): author suric,
# "update apps and examples", commit 48860c6.
import time
from pathlib import Path
from tempfile import NamedTemporaryFile
import basic_pitch
import basic_pitch.inference
import gradio as gr
import torch
from audiocraft.data.audio import audio_write
from audiocraft.data.audio_utils import convert_audio
from audiocraft.models import AudioGen, MusicGen, MAGNeT
from basic_pitch import ICASSP_2022_MODEL_PATH
# from transformers import AutoModelForSeq2SeqLM
from concurrent.futures import ProcessPoolExecutor
import typing as tp
import warnings
import json
import ast
import torchaudio
MODEL = None  # Lazily-loaded global model instance; replaced when the requested version changes.


def load_model(version='facebook/musicgen-large'):
    """Load (and cache) the generation model for ``version``.

    The previous model is deleted first so its memory can be reclaimed
    before the new checkpoint is loaded.

    Args:
        version: Checkpoint name; the substring ("magnet", "musicgen",
            "musiclang", "audiogen") selects the model family.

    Returns:
        The loaded model (cached across calls for the same version).

    Raises:
        NotImplementedError: for MusicLang versions (not implemented yet).
        ValueError: for an unrecognized version string.
    """
    global MODEL
    if MODEL is None or MODEL.name != version:
        del MODEL
        MODEL = None  # in case loading would crash
        print("Loading model", version)
        if "magnet" in version:
            MODEL = MAGNeT.get_pretrained(version)
        elif "musicgen" in version:
            MODEL = MusicGen.get_pretrained(version)
        elif "musiclang" in version:
            # Previously fell through silently and returned None, which made
            # callers crash later with an opaque AttributeError. Fail loudly.
            raise NotImplementedError("MusicLang is not supported yet")
        elif "audiogen" in version:
            MODEL = AudioGen.get_pretrained(version)
        else:
            raise ValueError("Invalid model version")
    return MODEL
# Worker pool for background jobs; currently only referenced by the
# commented-out waveform-video rendering in _do_predictions.
pool = ProcessPoolExecutor(4)
class FileCleaner:
    """Tracks temporary files and deletes them once they exceed a lifetime.

    Files are stored oldest-first, so cleanup can stop at the first
    entry that has not expired yet.
    """

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime  # seconds a file is kept alive
        self.files = []  # list of (added_at_timestamp, Path), oldest first

    def add(self, path: tp.Union[str, Path]):
        """Register a file for deferred deletion, purging expired ones first."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        # Pop expired entries from the front until we hit a fresh one.
        now = time.time()
        while self.files:
            added_at, candidate = self.files[0]
            if now - added_at <= self.file_lifetime:
                break
            if candidate.exists():
                candidate.unlink()
            self.files.pop(0)
# Module-wide cleaner: generated temp files are removed after the default
# lifetime (3600 s) the next time another file is registered.
file_cleaner = FileCleaner()
def inference_musicgen_text_to_music(model, configs, text, num_outputs=1):
    """Generate music from a text prompt with a MusicGen model.

    The prompt is replicated ``num_outputs`` times to form a batch.
    """
    model.set_generation_params(**configs)
    batch = [text] * num_outputs
    return model.generate(descriptions=batch, progress=True, return_tokens=False)
def inference_musicgen_continuation(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
    """Continue a prompt waveform with a MusicGen model.

    NOTE: ``text`` and ``num_outputs`` are currently unused — generation is
    driven purely by the audio prompt (text conditioning is commented out
    upstream).
    """
    model.set_generation_params(**configs)
    return model.generate_continuation(
        prompt_waveform,
        prompt_sample_rate=prompt_sr,
        progress=True,
        return_tokens=False,
    )
def inference_musicgen_melody_condition(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
    """Generate music conditioned on both a text prompt and a melody waveform."""
    model.set_generation_params(**configs)
    batch = [text] * num_outputs
    return model.generate_with_chroma(
        descriptions=batch,
        melody_wavs=prompt_waveform,
        melody_sample_rate=prompt_sr,
        progress=True,
        return_tokens=False,
    )
def inference_magnet(model, configs, text, num_outputs=1):
    """Generate music with a MAGNeT model from a text description."""
    model.set_generation_params(**configs)
    return model.generate(
        descriptions=[text] * num_outputs,
        progress=True,
        return_tokens=False,
    )
def inference_magnet_audio(model, configs, text, num_outputs=1):
    """Generate sound effects with an audio-MAGNeT model from a text description."""
    model.set_generation_params(**configs)
    prompts = num_outputs * [text]
    return model.generate(descriptions=prompts, progress=True, return_tokens=False)
def inference_audiogen(model, configs, text, num_outputs=1):
    """Generate environmental audio with an AudioGen model from a text description."""
    model.set_generation_params(**configs)
    descriptions = [text] * num_outputs
    output = model.generate(descriptions=descriptions, progress=True, return_tokens=False)
    return output
def inference_musiclang():
    """Placeholder for MusicLang inference; not implemented yet (returns None)."""
    # TODO: Implement MusicLang
    pass
def process_audio(gr_audio, prompt_duration, model):
    """Load a prompt audio file and keep only its first ``prompt_duration`` seconds.

    ``model`` is accepted for call-site symmetry but is not used here.
    Returns a (waveform, sample_rate) pair.
    """
    waveform, sample_rate = torchaudio.load(gr_audio)
    keep_samples = int(prompt_duration * sample_rate)
    return waveform[..., :keep_samples], sample_rate
# Dispatch table: model checkpoint name -> inference function.
# "musicgen-continuation" is not a checkpoint name; _do_predictions selects it
# for non-melody models when a prompt waveform is supplied.
_MODEL_INFERENCES = {
    "facebook/musicgen-small": inference_musicgen_text_to_music,
    "facebook/musicgen-medium": inference_musicgen_text_to_music,
    "facebook/musicgen-large": inference_musicgen_text_to_music,
    "facebook/musicgen-melody": inference_musicgen_melody_condition,
    "facebook/musicgen-melody-large": inference_musicgen_melody_condition,
    "facebook/magnet-small-10secs": inference_magnet,
    "facebook/magnet-medium-10secs": inference_magnet,
    "facebook/magnet-small-30secs": inference_magnet,
    "facebook/magnet-medium-30secs": inference_magnet,
    "facebook/audio-magnet-small": inference_magnet_audio,
    "facebook/audio-magnet-medium": inference_magnet_audio,
    "facebook/audiogen-medium": inference_audiogen,
    "musicgen-continuation": inference_musicgen_continuation,
}
def _do_predictions(
    model_file,
    model,
    text,
    melody=None,
    mel_sample_rate=None,
    progress=False,
    num_generations=1,
    **gen_kwargs,
):
    """Run the inference function matching ``model_file`` and write each output
    to a temporary WAV file.

    Args:
        model_file: Checkpoint name, used as key into _MODEL_INFERENCES.
        model: The loaded generation model.
        text: Text prompt (may be unused for pure continuation).
        melody: Optional prompt waveform; selects melody-conditioning or
            continuation depending on ``model_file``.
        mel_sample_rate: Sample rate of ``melody``.
        progress: Unused; kept for caller compatibility.
        num_generations: Batch size of generations.
        **gen_kwargs: Forwarded to ``model.set_generation_params``.

    Returns:
        List of paths to the generated WAV files (registered with file_cleaner).

    Raises:
        gr.Error: if the underlying generation raises a RuntimeError.
    """
    print(
        "new generation",
        text,
        None if melody is None else melody.shape
    )
    start = time.time()
    try:
        if melody is not None:
            # A prompt waveform selects melody conditioning or continuation.
            if 'melody' in model_file:
                # melody condition - musicgen-melody, musicgen-melody-large
                inference_func = _MODEL_INFERENCES[model_file]
            else:
                # melody continuation
                inference_func = _MODEL_INFERENCES['musicgen-continuation']
            outputs = inference_func(model, gen_kwargs, text, melody, mel_sample_rate, num_generations)
        else:
            # text-to-music, text-to-sound
            inference_func = _MODEL_INFERENCES[model_file]
            outputs = inference_func(model, gen_kwargs, text, num_generations)
    except RuntimeError as e:
        # str(e) is safe even when e.args is empty (e.args[0] would raise
        # IndexError); chain the cause for debuggability.
        raise gr.Error("Error while generating " + str(e)) from e
    outputs = outputs.detach().cpu().float()
    out_audios = []
    # NOTE: waveform-video rendering (pool.submit(make_waveform, ...)) is
    # currently disabled; only audio files are produced.
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name,
                output,
                model.sample_rate,
                strategy="loudness",
                loudness_headroom_db=16,
                loudness_compressor=True,
                add_suffix=False,
            )
            out_audios.append(file.name)
            file_cleaner.add(file.name)
    print("generation finished", len(outputs), time.time() - start)
    return out_audios
def make_waveform(*args, **kwargs):
    """Render a waveform video via gradio while silencing its warnings."""
    started = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        result = gr.make_waveform(*args, **kwargs)
    print("Make a video took", time.time() - started)
    return result
def predict(
    model_version,
    generation_configs,
    prompt_text=None,
    prompt_wav=None,
    num_generations=1,
    progress=gr.Progress(),
):
    """Top-level generation entry point for the UI.

    Loads the requested model, parses ``generation_configs`` (accepted as a
    dict or as its Python-literal string form), optionally trims the prompt
    audio to the configured duration, and returns the generated WAV paths.

    Raises:
        gr.Error: when the user interrupts, or generation fails downstream.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")

    # Initialize BEFORE defining/registering the callback so _progress can
    # never observe max_generated while it is still unbound.
    max_generated = 0

    def _progress(generated, to_generate):
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model = load_model(model_version)
    model.set_custom_progress_callback(_progress)

    if isinstance(generation_configs, str):
        # Configs may arrive from the UI as the repr of a dict; parse it
        # safely (literal_eval, never eval).
        generation_configs = ast.literal_eval(generation_configs)

    if prompt_wav is not None:
        melody, mel_sample_rate = process_audio(prompt_wav, generation_configs['duration'], model)
    else:
        melody, mel_sample_rate = None, None

    audios = _do_predictions(
        model_version,
        model,
        prompt_text,
        melody,
        mel_sample_rate,
        progress=True,
        num_generations=num_generations,
        **generation_configs,
    )
    return audios
def transcribe(audio_path):
    """
    Transcribe audio files to MIDI using the basic_pitch model.

    ``audio_path`` is the Python-literal string form of a list of WAV paths.
    Returns one gr.DownloadButton per transcribed MIDI file; the MIDI files
    are registered with file_cleaner for delayed deletion.
    """
    wav_paths = ast.literal_eval(audio_path)
    download_buttons = []
    for wav_path in wav_paths:
        _, midi_data, _ = basic_pitch.inference.predict(
            audio_path=wav_path,
            model_or_model_path=ICASSP_2022_MODEL_PATH,
        )
        with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
            try:
                midi_data.write(file)
                print(f"midi file saved to {file.name}")
            except Exception as e:
                print(f"Error while writing midi file: {e}")
                raise e
        download_buttons.append(gr.DownloadButton(
            value=file.name, label=f"Download MIDI file {file.name}", visible=True
        ))
        file_cleaner.add(file.name)
    return download_buttons