File size: 5,180 Bytes
c318a73
 
6a24aec
 
4336e0a
 
6a24aec
 
 
 
a99ae87
4336e0a
6a24aec
c318a73
 
6a24aec
a99ae87
 
 
 
c318a73
 
6a24aec
061e8b0
6a24aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c318a73
 
 
 
 
 
6a24aec
 
 
 
97a428f
c318a73
 
6a24aec
c318a73
 
 
 
 
 
 
 
 
 
 
6a24aec
c318a73
 
061e8b0
 
 
 
 
 
 
 
 
 
c318a73
 
 
 
 
 
 
 
6a24aec
 
 
 
 
 
 
 
c318a73
 
 
 
 
6a24aec
 
 
 
 
 
 
 
 
 
 
c318a73
 
 
 
 
97a428f
 
 
 
 
 
c318a73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a24aec
c318a73
 
 
061e8b0
3e1d74a
c318a73
 
 
 
 
 
 
 
 
6a24aec
 
4336e0a
 
 
 
 
 
 
 
6a24aec
4336e0a
 
 
 
 
 
 
 
 
 
6a24aec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import time
from pathlib import Path
from tempfile import NamedTemporaryFile

import basic_pitch
import basic_pitch.inference
import gradio as gr
import torch
from audiocraft.data.audio import audio_write
from audiocraft.data.audio_utils import convert_audio
from audiocraft.models import MusicGen, AudioGen
from basic_pitch import ICASSP_2022_MODEL_PATH
from transformers import AutoModelForSeq2SeqLM


def load_model(version="facebook/musicgen-melody"):
    """Load a pretrained generation model for *version*.

    AudioGen checkpoints get an ``AudioGen`` instance; every other
    version string is assumed to be a MusicGen checkpoint.
    """
    audiogen_versions = {"facebook/audiogen-medium"}
    if version in audiogen_versions:
        return AudioGen.get_pretrained(version)
    return MusicGen.get_pretrained(version)


def _do_predictions(
    model_file,
    model,
    texts,
    melodies,
    duration,
    progress=False,
    gradio_progress=None,
    target_sr=32000,
    target_ac=1,
    **gen_kwargs,
):
    """Run one generation batch and write each result to a temporary WAV file.

    Args:
        model_file: model identifier string (e.g. ``"facebook/audiogen-medium"``);
            kept for call compatibility even though generation no longer
            branches on it (see note below).
        model: a loaded MusicGen/AudioGen instance.
        texts: list of text prompts, one per batch item.
        melodies: list parallel to ``texts``; each entry is ``None`` or a
            ``(sample_rate, numpy_waveform)`` pair as produced by gradio audio
            inputs.
        duration: maximum melody length to keep for conditioning, in seconds.
        progress: forwarded to the model's generate call(s).
        gradio_progress: accepted for interface compatibility; unused here.
        target_sr: sample rate melodies are resampled to before conditioning.
        target_ac: channel count melodies are converted to.
        **gen_kwargs: NOTE(review) currently ignored — sampling parameters such
            as top_k/top_p/temperature passed by the caller are never forwarded
            to the model; confirm whether they should be applied via
            ``model.set_generation_params``.

    Returns:
        List of filesystem paths to the generated WAV files. The files are
        created with ``delete=False`` so they outlive this function (gradio
        serves them afterwards); callers own their cleanup.

    Raises:
        gr.Error: wrapping any ``RuntimeError`` raised during generation.
    """
    print(
        "new batch",
        len(texts),
        texts,
        [None if m is None else (m[0], m[1].shape) for m in melodies],
    )
    started = time.time()

    # Normalize every melody to a (channels, time) float tensor at the
    # target sample rate / channel count, truncated to `duration` seconds.
    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
            continue
        sr = melody[0]
        wav = torch.from_numpy(melody[1]).to(model.device).float().t()
        print(f"Input audio sample rate is {sr}")
        if wav.dim() == 1:
            # Promote a mono vector to a (1, time) channel-first tensor.
            wav = wav[None]
        wav = wav[..., : int(sr * duration)]
        processed_melodies.append(convert_audio(wav, sr, target_sr, target_ac))

    try:
        if any(m is not None for m in processed_melodies):
            # Melody-conditioned generation.
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False,
            )
        else:
            # Text-only generation. The former audiogen-specific branch issued
            # an identical call, so the duplication is collapsed here.
            outputs = model.generate(texts, progress=progress)
    except RuntimeError as e:
        # Surface backend failures in the UI while preserving the cause chain.
        raise gr.Error("Error while generating " + e.args[0]) from e

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # delete=False: gradio must be able to read the file after we return.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name,
                output,
                model.sample_rate,
                strategy="loudness",
                loudness_headroom_db=16,
                loudness_compressor=True,
                add_suffix=False,
            )
            out_wavs.append(file.name)
    print("generation finished", len(texts), time.time() - started)
    return out_wavs


def predict(
    model_path,
    text,
    melody,
    duration,
    topk,
    topp,
    temperature,
    target_sr,
    progress=gr.Progress(),
):
    """Gradio handler: generate one audio clip from a prompt.

    Args:
        model_path: model identifier/path string passed to ``load_model``.
        text: text prompt describing the desired audio.
        melody: optional ``(sample_rate, numpy_waveform)`` pair for melody
            conditioning, or ``None``.
        duration: requested clip duration in seconds.
        topk: top-k sampling cutoff; must be non-negative (coerced to int).
        topp: top-p (nucleus) sampling cutoff; must be non-negative.
        temperature: sampling temperature; must be non-negative.
        target_sr: sample rate the melody is resampled to.
        progress: gradio progress tracker (gradio injects this; the mutable
            default is the documented gradio idiom).

    Returns:
        Path to the single generated WAV file.

    Raises:
        gr.Error: on invalid sampling parameters, generation failure, or when
            the global INTERRUPTING flag is set mid-generation.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    model_path = model_path.strip()

    # Validate sampling parameters before paying the model-load cost.
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    topk = int(topk)
    model = load_model(model_path)

    # Progress can briefly report fewer generated tokens than a previous
    # callback; track the high-water mark so the bar never moves backwards.
    max_generated = 0

    def _progress(generated, to_generate):
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model.set_custom_progress_callback(_progress)

    wavs = _do_predictions(
        model_path,
        model,
        [text],
        [melody],
        duration,
        progress=True,
        target_ac=1,
        target_sr=target_sr,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        gradio_progress=progress,
    )
    # Batch of one prompt -> exactly one output wav.
    return wavs[0]


def transcribe(audio_path):
    """Transcribe an audio file to MIDI with basic_pitch.

    Args:
        audio_path: path to the audio file to transcribe.

    Returns:
        A visible ``gr.DownloadButton`` pointing at the written MIDI file.
        The temp file uses ``delete=False`` so it survives for the download.

    Raises:
        Exception: re-raised (with traceback intact) if writing the MIDI
        file fails.
    """
    model_output, midi_data, note_events = basic_pitch.inference.predict(
        audio_path=audio_path,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
    )

    with NamedTemporaryFile("wb", suffix=".mid", delete=False) as file:
        try:
            midi_data.write(file)
            print(f"midi file saved to {file.name}")
        except Exception as e:
            print(f"Error while writing midi file: {e}")
            # Bare raise preserves the original traceback ("raise e" would
            # re-anchor it here).
            raise

    return gr.DownloadButton(
        value=file.name, label=f"Download MIDI file {file.name}", visible=True
    )