Spaces:
Sleeping
Sleeping
File size: 5,180 Bytes
c318a73 6a24aec 4336e0a 6a24aec a99ae87 4336e0a 6a24aec c318a73 6a24aec a99ae87 c318a73 6a24aec 061e8b0 6a24aec c318a73 6a24aec 97a428f c318a73 6a24aec c318a73 6a24aec c318a73 061e8b0 c318a73 6a24aec c318a73 6a24aec c318a73 97a428f c318a73 6a24aec c318a73 061e8b0 3e1d74a c318a73 6a24aec 4336e0a 6a24aec 4336e0a 6a24aec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import time
from pathlib import Path
from tempfile import NamedTemporaryFile
import basic_pitch
import basic_pitch.inference
import gradio as gr
import torch
from audiocraft.data.audio import audio_write
from audiocraft.data.audio_utils import convert_audio
from audiocraft.models import MusicGen, AudioGen
from basic_pitch import ICASSP_2022_MODEL_PATH
from transformers import AutoModelForSeq2SeqLM
def load_model(version="facebook/musicgen-melody"):
    """Load the pretrained generation model for `version`.

    The AudioGen checkpoint gets an AudioGen model; every other
    identifier is treated as a MusicGen checkpoint.
    """
    if version == "facebook/audiogen-medium":
        return AudioGen.get_pretrained(version)
    return MusicGen.get_pretrained(version)
def _do_predictions(
    model_file,
    model,
    texts,
    melodies,
    duration,
    progress=False,
    gradio_progress=None,
    target_sr=32000,
    target_ac=1,
    **gen_kwargs,
):
    """Generate audio for a batch of text prompts, optionally melody-conditioned.

    Args:
        model_file: Model identifier string; "facebook/audiogen-medium"
            selects the plain (non-chroma) generation path.
        model: Loaded MusicGen/AudioGen model instance.
        texts: List of text prompts, one per batch item.
        melodies: List parallel to `texts`; each item is None or a
            (sample_rate, numpy array) pair as produced by a Gradio audio input.
        duration: Maximum melody length to keep, in seconds.
        progress: Whether the model should report generation progress.
        gradio_progress: Accepted for interface compatibility; unused here.
        target_sr: Sample rate melody conditioning is resampled to.
        target_ac: Channel count melody conditioning is remixed to.
        **gen_kwargs: Extra keyword args; currently not forwarded to generation.

    Returns:
        List of paths to generated .wav files, one per prompt.

    Raises:
        gr.Error: If the model raises a RuntimeError during generation.
    """
    print(
        "new batch",
        len(texts),
        texts,
        [None if m is None else (m[0], m[1].shape) for m in melodies],
    )
    be = time.time()
    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            # Gradio audio is (sr, samples); transpose to (channels, time)
            # and move onto the model's device.
            sr, melody = (
                melody[0],
                torch.from_numpy(melody[1]).to(model.device).float().t(),
            )
            print(f"Input audio sample rate is {sr}")
            if melody.dim() == 1:
                melody = melody[None]
            # Trim to the requested duration, then resample/remix to target.
            melody = melody[..., : int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)
    try:
        if any(m is not None for m in processed_melodies):
            # melody condition
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False,
            )
        else:
            # audio (AudioGen) or text-only (MusicGen) generation — both use
            # the same plain generate() call.
            outputs = model.generate(texts, progress=progress)
    except RuntimeError as e:
        # str(e) is safe even when e.args is empty (e.args[0] would raise
        # IndexError) or the first arg is not a string.
        raise gr.Error(f"Error while generating {e}") from e
    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # delete=False: the file must outlive the handle so callers can serve it.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name,
                output,
                model.sample_rate,
                strategy="loudness",
                loudness_headroom_db=16,
                loudness_compressor=True,
                add_suffix=False,
            )
            out_wavs.append(file.name)
    print("generation finished", len(texts), time.time() - be)
    return out_wavs
def predict(
    model_path,
    text,
    melody,
    duration,
    topk,
    topp,
    temperature,
    target_sr,
    progress=gr.Progress(),
):
    """Gradio entry point: validate sampling params, load the model, and
    generate one clip for `text`, optionally conditioned on `melody`.

    Args:
        model_path: Model identifier, e.g. "facebook/musicgen-melody".
        text: Text prompt for generation.
        melody: Optional (sample_rate, numpy array) melody conditioning.
        duration: Maximum melody duration in seconds.
        topk: Top-k sampling parameter; must be non-negative (coerced to int).
        topp: Top-p sampling parameter; must be non-negative.
        temperature: Sampling temperature; must be >= 0.
        target_sr: Target sample rate for melody conditioning.
        progress: Gradio progress tracker (default instance is the Gradio
            convention for progress injection).

    Returns:
        Path to the single generated .wav file.

    Raises:
        gr.Error: On invalid sampling parameters or user interruption.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    model_path = model_path.strip()
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")
    topk = int(topk)
    model = load_model(model_path)

    max_generated = 0

    def _progress(generated, to_generate):
        # Keep a monotone high-water mark: the model may report counts
        # out of order, and progress must never move backwards.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model.set_custom_progress_callback(_progress)
    wavs = _do_predictions(
        model_path,
        model,
        [text],
        [melody],
        duration,
        progress=True,
        target_ac=1,
        target_sr=target_sr,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        gradio_progress=progress,
    )
    return wavs[0]
def transcribe(audio_path):
    """Run basic_pitch transcription on `audio_path`, save the result as a
    temporary .mid file, and return a visible gr.DownloadButton for it."""
    prediction = basic_pitch.inference.predict(
        audio_path=audio_path,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
    )
    midi_data = prediction[1]  # (model_output, midi_data, note_events)
    # delete=False: the download button must serve the file after this scope.
    with NamedTemporaryFile("wb", suffix=".mid", delete=False) as midi_file:
        try:
            midi_data.write(midi_file)
            print(f"midi file saved to {midi_file.name}")
        except Exception as e:
            print(f"Error while writing midi file: {e}")
            raise e
    return gr.DownloadButton(
        value=midi_file.name,
        label=f"Download MIDI file {midi_file.name}",
        visible=True,
    )
|