File size: 5,853 Bytes
d4cef17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from typing import List, Tuple

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import note_seq
from matplotlib.figure import Figure
from numpy import ndarray
import torch

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from string_to_notes import token_sequence_to_note_sequence
from model import get_model_and_tokenizer


# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced again in the visible part of this
# file (generation moves inputs to `model.device`) — confirm whether other
# code relies on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the model once, at import time; the functions below
# use these module-level objects directly.
model, tokenizer = get_model_and_tokenizer()


def create_seed_string(genre: str = "OTHER") -> str:
    """
    Builds the seed text used to start generating a new piece.

    Args:
        genre (str): The genre to condition on. The special value "RANDOM"
            leaves the genre (and the first track start) out of the seed.

    Returns:
        str: "PIECE_START" when the genre is "RANDOM", otherwise
            "PIECE_START GENRE=<genre> TRACK_START".
    """
    # Guard clause: a random piece gets the bare start token only.
    if genre == "RANDOM":
        return "PIECE_START"
    return f"PIECE_START GENRE={genre} TRACK_START"


def get_instruments(text_sequence: str) -> List[str]:
    """
    Collects a display name for every "INST=" token in a text sequence.

    Args:
        text_sequence (str): Whitespace-separated token text.

    Returns:
        List[str]: One entry per "INST=" token, in order of appearance:
            "Drums" for "INST=DRUMS", otherwise the GM_INSTRUMENTS entry at
            the integer index that follows "INST=".
    """
    prefix = "INST="
    names = []
    for token in text_sequence.split():
        # Only instrument tokens are of interest; skip everything else.
        if not token.startswith(prefix):
            continue
        value = token[len(prefix):]
        names.append("Drums" if value == "DRUMS" else GM_INSTRUMENTS[int(value)])
    return names


def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """
    Continues a seed sequence until "TRACK_END", retrying until notes appear.

    Sampling is repeated until the text generated beyond the seed contains
    "NOTE_ON". There is no cap on the number of attempts, so this call only
    returns once such a continuation has been sampled.

    Args:
        seed (str): The token text to condition the model on.
        temp (float): The sampling temperature passed to `model.generate`.

    Returns:
        str: The decoded full sequence (seed plus newly generated tokens).
    """
    # The seed and the stop token are identical for every attempt, so they are
    # tokenized (and moved to the model's device) once instead of per retry.
    seed_length = len(tokenizer.encode(seed))
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    while True:
        # Sample a continuation that stops at "TRACK_END" or after 2048 tokens.
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )[0]

        # Only the part beyond the seed counts: the seed itself may already
        # contain "NOTE_ON" from earlier tracks.
        new_generated_sequence = tokenizer.decode(generated_ids[seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            # Decode the full sequence only for the attempt that is kept.
            return tokenizer.decode(generated_ids)


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Renders a token sequence into the outputs shown by the app.

    The MIDI file is always written to the same relative path, so each call
    overwrites the file produced by the previous one.

    Args:
        generated_sequence (str): The token text to render.
        qpm (int): The tempo forwarded to `token_sequence_to_note_sequence`.

    Returns:
        A tuple of: the value returned by `gr.make_waveform` for the
        synthesized audio, the path of the written MIDI file, the figure from
        `note_seq.plot_sequence`, a "- name" bullet list of the instruments
        (one per line), and the whitespace-separated token count as a string.
    """
    midi_path = "midi_ouput.mid"

    instruments_str = "\n".join(
        [f"- {name}" for name in get_instruments(generated_sequence)]
    )
    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    # Synthesize float samples, then convert them to int16 for the waveform.
    samples = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(samples)

    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
    token_count = len(generated_sequence.split())
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, midi_path)

    return audio, midi_path, fig, instruments_str, str(token_count)


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Drops the last track of a song and re-renders what is left.

    If removing the last track leaves no track at all, a new one is generated
    instead, seeded with whatever text preceded the first 'TRACK_START'
    (an empty seed when there was no track to begin with).

    Args:
        text_sequence (str): The token text of the current song.
        qpm (int): The tempo used for rendering, also when a new track has to
            be generated.

    Returns:
        The same tuple as `generate_song`: audio, MIDI file path, figure,
        instrument list, the resulting token text, and its token count.
    """
    # Splitting on 'TRACK_START' gives the text before the first track,
    # followed by one chunk per track.
    tracks = text_sequence.split("TRACK_START")
    # Keep everything except the last track, restoring the 'TRACK_START'
    # separators that split() removed.
    new_song = "TRACK_START".join(tracks[:-1])

    if len(tracks) <= 2:
        # Zero or one track: nothing playable remains. `new_song` is "" when
        # there was no track, or the text before the only track otherwise, so
        # a single call covers both cases. The requested tempo is forwarded so
        # it is not silently replaced by generate_song's default.
        return generate_song(text_sequence=new_song, qpm=qpm)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        new_song, qpm
    )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last track of a song while keeping its instrument.

    The song is cut right after its last "INST=" token and generation
    continues from there. Without any "INST=" token, a song is generated from
    an empty sequence.

    Args:
        text_sequence (str): The token text of the current song.
        qpm (int): The tempo used for rendering.

    Returns:
        The same tuple as `generate_song`: audio, MIDI file path, figure,
        instrument list, the resulting token text, and its token count.
    """
    last_inst_index = text_sequence.rfind("INST=")
    if last_inst_index == -1:
        # No instrument so start from empty sequence
        return generate_song(text_sequence="", qpm=qpm)

    # Keep everything up to the end of the last instrument token.
    next_space_index = text_sequence.find(" ", last_inst_index)
    if next_space_index == -1:
        # The instrument token is the last thing in the text. Slicing with -1
        # would drop its final character, so keep the whole text as the seed.
        new_seed = text_sequence
    else:
        new_seed = text_sequence[:next_space_index]

    return generate_song(text_sequence=new_seed, qpm=qpm)


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Re-renders an unchanged song at a different tempo.

    Args:
        text_sequence (str): The token text of the current song.
        qpm (int): The tempo to render at.

    Returns:
        The outputs of `get_outputs_from_string`, with the unchanged
        `text_sequence` inserted just before the token count.
    """
    # Everything but the trailing token count comes first in the result.
    *rendered, num_tokens = get_outputs_from_string(text_sequence, qpm=qpm)
    return (*rendered, text_sequence, num_tokens)


def generate_song(
    genre: str = "OTHER",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates one more track and renders the resulting song.

    Args:
        genre (str): The genre for the seed; only used when `text_sequence`
            is empty.
        temp (float): The sampling temperature for generation.
        text_sequence (str): Existing token text to continue from. An empty
            string starts a new piece from a genre seed.
        qpm (int): The tempo used for rendering.

    Returns:
        A tuple of: audio, MIDI file path, figure, instrument list, the full
        generated token text, and its token count.
    """
    # An empty sequence means a fresh piece, so build the seed from the genre.
    seed_string = create_seed_string(genre) if text_sequence == "" else text_sequence
    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)

    rendered = get_outputs_from_string(generated_sequence, qpm)
    audio, midi_file, fig, instruments_str, num_tokens = rendered
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens