Spaces:
Sleeping
Sleeping
import time | |
from pathlib import Path | |
from tempfile import NamedTemporaryFile | |
import basic_pitch | |
import basic_pitch.inference | |
import gradio as gr | |
import torch | |
from audiocraft.data.audio import audio_write | |
from audiocraft.data.audio_utils import convert_audio | |
from audiocraft.models import AudioGen, MusicGen, MAGNeT | |
from basic_pitch import ICASSP_2022_MODEL_PATH | |
# from transformers import AutoModelForSeq2SeqLM | |
from concurrent.futures import ProcessPoolExecutor | |
import typing as tp | |
import warnings | |
import json | |
import ast | |
import torchaudio | |
# Most recently loaded model, cached at module level and reused across calls.
MODEL = None


def load_model(version='facebook/musicgen-large'):
    """Return the pretrained model for ``version``, loading it on demand.

    The last loaded model is kept in the module-level ``MODEL`` and reused
    for as long as the requested ``version`` equals ``MODEL.name``.

    Args:
        version: Checkpoint identifier. The model family is selected by
            substring, checked in this order: "magnet", "musicgen",
            "musiclang", "audiogen".

    Returns:
        The loaded model instance (also stored in ``MODEL``).

    Raises:
        NotImplementedError: If a "musiclang" version is requested; that
            family has no loader yet.
        ValueError: If ``version`` matches none of the known families.
    """
    global MODEL
    if MODEL is None or MODEL.name != version:
        # Release the reference to the previous model before loading.
        del MODEL
        MODEL = None  # in case loading would crash
        print("Loading model", version)
        if "magnet" in version:
            MODEL = MAGNeT.get_pretrained(version)
        elif "musicgen" in version:
            MODEL = MusicGen.get_pretrained(version)
        elif "musiclang" in version:
            # TODO: Implement MusicLang
            # Fail loudly instead of returning None, which would only
            # surface later as an AttributeError in the caller.
            raise NotImplementedError("MusicLang models are not supported yet")
        elif "audiogen" in version:
            MODEL = AudioGen.get_pretrained(version)
        else:
            raise ValueError("Invalid model version")
    return MODEL
# Module-level pool of four worker processes. The only visible use (rendering
# waveform videos in _do_predictions) is currently commented out.
pool = ProcessPoolExecutor(max_workers=4)
class FileCleaner:
    """Track temporary files and delete them once they outlive a lifetime.

    Entries are stored oldest-first together with the time they were
    registered. Every call to :meth:`add` first purges expired entries, so
    cleanup only happens when new files are registered.
    """

    def __init__(self, file_lifetime: float = 3600):
        # Number of seconds a file is kept after being registered.
        self.file_lifetime = file_lifetime
        # List of (time_added, path) tuples in insertion order.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired files, then register ``path`` for later deletion."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        """Delete tracked files older than ``file_lifetime`` seconds.

        Scanning stops at the first entry that is still fresh, because
        entries are appended in chronological order.
        """
        now = time.time()
        expired = 0
        try:
            for time_added, path in self.files:
                if now - time_added <= self.file_lifetime:
                    break
                # missing_ok avoids the exists()/unlink() race when the file
                # has already been removed by something else.
                path.unlink(missing_ok=True)
                expired += 1
        finally:
            # Drop all processed entries in one slice instead of pop(0) per
            # entry; runs even if an unlink failed part-way through.
            del self.files[:expired]


file_cleaner = FileCleaner()
def inference_musicgen_text_to_music(model, configs, text, num_outputs=1):
    """Run MusicGen text-to-music generation for a single prompt.

    ``configs`` is forwarded as keyword arguments to
    ``model.set_generation_params``. ``text`` is repeated ``num_outputs``
    times as the description list, and whatever ``model.generate`` returns
    is passed back unchanged.
    """
    model.set_generation_params(**configs)
    # One description per requested output.
    prompts = [text] * num_outputs
    return model.generate(descriptions=prompts, progress=True, return_tokens=False)
def inference_musicgen_continuation(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
    """Continue an audio prompt with MusicGen, conditioned on ``text``.

    Args:
        model: Object exposing ``set_generation_params`` and
            ``generate_continuation``.
        configs: Keyword arguments forwarded to ``set_generation_params``.
        text: Text description applied to every continuation (may be None).
        prompt_waveform: Audio prompt tensor, either 2-D or already batched
            as 3-D. Assumed layout is (channels, samples) /
            (batch, channels, samples) -- TODO confirm against the caller.
        prompt_sr: Sample rate of ``prompt_waveform``.
        num_outputs: Number of continuations to produce from a single prompt.

    Returns:
        Whatever ``model.generate_continuation`` returns.
    """
    model.set_generation_params(**configs)
    prompt = prompt_waveform
    if prompt.dim() == 2:
        # Add a leading batch dimension.
        prompt = prompt.unsqueeze(0)
    if prompt.shape[0] == 1 and num_outputs > 1:
        # Replicate the single prompt so each requested output has its own
        # batch entry; previously num_outputs was silently ignored.
        prompt = prompt.repeat(num_outputs, 1, 1)
    # One description per batch entry; previously the text was dropped.
    descriptions = [text] * prompt.shape[0]
    output = model.generate_continuation(
        prompt,
        prompt_sample_rate=prompt_sr,
        descriptions=descriptions,
        progress=True,
        return_tokens=False,
    )
    return output
def inference_musicgen_melody_condition(model, configs, text, prompt_waveform, prompt_sr, num_outputs=1):
    """Generate music from ``text`` conditioned on a melody waveform.

    Args:
        model: Object exposing ``set_generation_params`` and
            ``generate_with_chroma``.
        configs: Keyword arguments forwarded to ``set_generation_params``.
        text: Text description, repeated once per requested output.
        prompt_waveform: Melody audio. A single 2-D tensor is replicated per
            output; any other value is forwarded unchanged.
        prompt_sr: Sample rate of the melody audio.
        num_outputs: Number of clips to generate.

    Returns:
        Whatever ``model.generate_with_chroma`` returns.
    """
    model.set_generation_params(**configs)
    descriptions = [text] * num_outputs
    melody_wavs = prompt_waveform
    if hasattr(prompt_waveform, "dim") and prompt_waveform.dim() == 2:
        # Pair every description with the same melody so the two lists have
        # equal length when num_outputs > 1.
        melody_wavs = [prompt_waveform] * num_outputs
    output = model.generate_with_chroma(
        descriptions=descriptions,
        melody_wavs=melody_wavs,
        melody_sample_rate=prompt_sr,
        progress=True,
        return_tokens=False,
    )
    return output
def inference_magnet(model, configs, text, num_outputs=1):
    """Run MAGNeT text-to-music generation for a single prompt.

    Applies ``configs`` through ``model.set_generation_params`` and returns
    the result of ``model.generate`` for ``num_outputs`` copies of ``text``.
    """
    model.set_generation_params(**configs)
    repeated_text = [text] * num_outputs
    generated = model.generate(
        descriptions=repeated_text,
        progress=True,
        return_tokens=False,
    )
    return generated
def inference_magnet_audio(model, configs, text, num_outputs=1):
    """Run Audio-MAGNeT text-to-sound generation for a single prompt.

    ``configs`` is unpacked into ``model.set_generation_params``; ``text`` is
    duplicated ``num_outputs`` times and handed to ``model.generate``, whose
    result is returned as-is.
    """
    model.set_generation_params(**configs)
    generate_kwargs = {
        "descriptions": [text] * num_outputs,
        "progress": True,
        "return_tokens": False,
    }
    return model.generate(**generate_kwargs)
def inference_audiogen(model, configs, text, num_outputs=1):
    """Run AudioGen text-to-sound generation for a single prompt.

    Sets the generation parameters from ``configs`` and returns the result
    of ``model.generate`` called with ``text`` repeated ``num_outputs`` times.
    """
    model.set_generation_params(**configs)
    # Same prompt for every requested sample.
    batch = [text] * num_outputs
    result = model.generate(descriptions=batch, progress=True, return_tokens=False)
    return result
def inference_musiclang():
    """Placeholder for MusicLang inference; does nothing and returns ``None``."""
    # TODO: Implement MusicLang
    return None
def process_audio(gr_audio, prompt_duration, model):
    """Load an audio file and keep at most ``prompt_duration`` seconds of it.

    Args:
        gr_audio: Value passed straight to ``torchaudio.load``.
        prompt_duration: Maximum prompt length in seconds.
        model: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(waveform, sample_rate)`` with the waveform truncated along
        its last dimension.
    """
    waveform, sample_rate = torchaudio.load(gr_audio)
    max_samples = int(prompt_duration * sample_rate)
    return waveform[..., :max_samples], sample_rate
# Maps a model identifier to the function that runs inference for it. The
# "musicgen-continuation" entry is not a checkpoint name; _do_predictions uses
# it when an audio prompt is given to a model without 'melody' in its name.
_MODEL_INFERENCES = {
    # MusicGen text-to-music checkpoints.
    **dict.fromkeys(
        (
            "facebook/musicgen-small",
            "facebook/musicgen-medium",
            "facebook/musicgen-large",
        ),
        inference_musicgen_text_to_music,
    ),
    # MusicGen melody-conditioned checkpoints.
    **dict.fromkeys(
        (
            "facebook/musicgen-melody",
            "facebook/musicgen-melody-large",
        ),
        inference_musicgen_melody_condition,
    ),
    # MAGNeT music checkpoints.
    **dict.fromkeys(
        (
            "facebook/magnet-small-10secs",
            "facebook/magnet-medium-10secs",
            "facebook/magnet-small-30secs",
            "facebook/magnet-medium-30secs",
        ),
        inference_magnet,
    ),
    # Audio-MAGNeT checkpoints.
    **dict.fromkeys(
        (
            "facebook/audio-magnet-small",
            "facebook/audio-magnet-medium",
        ),
        inference_magnet_audio,
    ),
    "facebook/audiogen-medium": inference_audiogen,
    "musicgen-continuation": inference_musicgen_continuation,
}
def _do_predictions(
    model_file,
    model,
    text,
    melody=None,
    mel_sample_rate=None,
    progress=False,
    num_generations=1,
    **gen_kwargs,
):
    """Run generation and write each output to a temporary WAV file.

    Args:
        model_file: Model identifier used to pick the inference function
            from ``_MODEL_INFERENCES``.
        model: Loaded model passed to the inference function; its
            ``sample_rate`` is used when writing audio.
        text: Text prompt.
        melody: Optional audio prompt. With it, models whose identifier
            contains 'melody' use their own entry (melody conditioning);
            all others use the "musicgen-continuation" entry.
        mel_sample_rate: Sample rate of ``melody``.
        progress: Unused; kept for interface compatibility.
        num_generations: Number of outputs to request.
        **gen_kwargs: Generation parameters handed to the inference function.

    Returns:
        List of paths to the generated WAV files. Each path is registered
        with ``file_cleaner``.

    Raises:
        gr.Error: If generation raises a ``RuntimeError``.
        KeyError: If ``model_file`` has no entry in ``_MODEL_INFERENCES``.
    """
    print(
        "new generation",
        text,
        None if melody is None else melody.shape
    )
    be = time.time()

    if melody is not None and 'melody' not in model_file:
        # Audio prompt with a non-melody model: continuation.
        inference_func = _MODEL_INFERENCES['musicgen-continuation']
    else:
        # Melody conditioning, text-to-music or text-to-sound.
        inference_func = _MODEL_INFERENCES[model_file]

    try:
        if melody is not None:
            outputs = inference_func(model, gen_kwargs, text, melody, mel_sample_rate, num_generations)
        else:
            outputs = inference_func(model, gen_kwargs, text, num_generations)
    except RuntimeError as e:
        # Format the exception itself: e.args may be empty or hold a
        # non-string, which would break string concatenation.
        raise gr.Error(f"Error while generating {e}") from e

    outputs = outputs.detach().cpu().float()
    out_audios = []
    for output in outputs:
        # delete=False: the file must outlive this function; file_cleaner
        # removes it later.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name,
                output,
                model.sample_rate,
                strategy="loudness",
                loudness_headroom_db=16,
                loudness_compressor=True,
                add_suffix=False,
            )
            out_audios.append(file.name)
            file_cleaner.add(file.name)
    print("generation finished", len(outputs), time.time() - be)
    return out_audios
def make_waveform(*args, **kwargs):
    """Call ``gr.make_waveform`` with warnings silenced and print the time taken.

    All arguments are forwarded unchanged and its result is returned.
    """
    started = time.time()
    with warnings.catch_warnings():
        # Further remove some warnings.
        warnings.simplefilter('ignore')
        video = gr.make_waveform(*args, **kwargs)
        elapsed = time.time() - started
        print("Make a video took", elapsed)
        return video
def predict(
    model_version,
    generation_configs,
    prompt_text=None,
    prompt_wav=None,
    num_generations=1,
    progress=gr.Progress(),
):
    """Load the requested model and generate audio for the given prompts.

    Args:
        model_version: Model identifier passed to ``load_model``.
        generation_configs: Generation parameters as a dict, or as the
            string form of a Python literal dict.
        prompt_text: Optional text prompt.
        prompt_wav: Optional audio prompt, loaded through ``process_audio``
            and truncated to ``generation_configs['duration']`` seconds.
        num_generations: Number of outputs to generate.
        progress: Gradio progress tracker.

    Returns:
        List of paths to the generated audio files.

    Raises:
        gr.Error: If the model cannot be loaded, or when ``INTERRUPTING``
            is set while generation is reporting progress.
    """
    global INTERRUPTING
    # Reset the interruption flag. Presumably another handler sets it to
    # True -- not visible in this part of the file.
    INTERRUPTING = False
    progress(0, desc="Loading model...")

    # Highest step reported so far; initialised before the callback that
    # declares it nonlocal.
    max_generated = 0

    def _progress(generated, to_generate):
        nonlocal max_generated
        # Report the running maximum so the displayed progress never goes
        # backwards, capped at the total.
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model = load_model(model_version)
    if model is None:
        # Give a clear message rather than an AttributeError on the next line.
        raise gr.Error(f"Model {model_version} is not supported yet.")
    model.set_custom_progress_callback(_progress)

    if isinstance(generation_configs, str):
        # literal_eval only evaluates Python literals, never arbitrary code.
        generation_configs = ast.literal_eval(generation_configs)

    if prompt_wav is not None:
        melody, mel_sample_rate = process_audio(prompt_wav, generation_configs['duration'], model)
    else:
        melody, mel_sample_rate = None, None

    audios = _do_predictions(
        model_version,
        model,
        prompt_text,
        melody,
        mel_sample_rate,
        progress=True,
        num_generations=num_generations,
        **generation_configs,
    )
    return audios
def transcribe(audio_path):
    """Transcribe audio files to MIDI using the basic_pitch model.

    Args:
        audio_path: The audio file paths, either as a list or as the string
            form of a Python literal list.

    Returns:
        A list with one visible ``gr.DownloadButton`` per input file, each
        pointing at a temporary ``.mid`` file registered with
        ``file_cleaner``.

    Raises:
        Exception: Whatever ``midi_data.write`` raises is logged and
            re-raised unchanged.
    """
    # Accept an already-parsed list as well as its string form, matching how
    # predict() treats generation_configs.
    if isinstance(audio_path, str):
        wav_paths = ast.literal_eval(audio_path)
    else:
        wav_paths = audio_path

    download_buttons = []
    for wav_path in wav_paths:
        _model_output, midi_data, _note_events = basic_pitch.inference.predict(
            audio_path=wav_path,
            model_or_model_path=ICASSP_2022_MODEL_PATH,
        )
        # delete=False: the file must survive so it can be downloaded;
        # file_cleaner removes it later.
        with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
            try:
                midi_data.write(file)
            except Exception as e:
                print(f"Error while writing midi file: {e}")
                raise  # bare raise keeps the original traceback
            print(f"midi file saved to {file.name}")
        # Build the button after the file is closed, so it is fully written.
        download_buttons.append(gr.DownloadButton(
            value=file.name, label=f"Download MIDI file {file.name}", visible=True
        ))
        file_cleaner.add(file.name)
    return download_buttons