File size: 11,862 Bytes
2a0ab09
 
 
 
 
 
 
 
 
 
 
 
 
c9d7e0a
2a0ab09
 
 
 
 
 
c9d7e0a
 
 
2a0ab09
c9d7e0a
 
2a0ab09
 
 
 
 
 
 
 
 
c909842
c9d7e0a
c909842
c9d7e0a
c909842
c9d7e0a
2a0ab09
c9d7e0a
2a0ab09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9d7e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a0ab09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9d7e0a
 
 
 
 
 
 
2a0ab09
 
 
 
 
 
 
 
 
 
 
 
 
 
c9d7e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a0ab09
c9d7e0a
 
 
 
 
 
 
2a0ab09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0955573
c9d7e0a
2a0ab09
 
c9d7e0a
2a0ab09
 
 
 
 
 
 
 
0955573
2a0ab09
 
 
 
 
 
 
 
c9d7e0a
 
 
 
 
 
 
2a0ab09
c9d7e0a
2a0ab09
c9d7e0a
2a0ab09
 
 
 
 
 
c9d7e0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
from typing import List, Tuple

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import note_seq
from matplotlib.figure import Figure
from numpy import ndarray
import torch

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from string_to_notes import token_sequence_to_note_sequence
from model import get_model_and_tokenizer

import json

# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced in the visible part of this file
# (generation uses `model.device` instead) — confirm whether it is still needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and its tokenizer once, at import time; both are used as
# module-level globals by generate_new_instrument().
model, tokenizer = get_model_and_tokenizer()

# Instrument names offered to the user, read from instruments.json.
# The functions below rely on the ordering: position 0 is treated as the drum
# kit ("INST=DRUMS") and position i (i >= 1) as instrument number i - 1.
with open('instruments.json', 'r') as f:
    instruments = json.load(f)


def create_seed_string(genre: str = "OTHER", artist: str = "OTHER", instrument:str="0") -> str:
    """
    Creates a seed string for generating a new piece.

    The values are inserted verbatim, so "RANDOM" needs no special handling:
    passing genre="RANDOM" simply yields "GENRE=RANDOM" (same for the artist).

    Args:
        genre (str, optional): The genre of the piece. Defaults to "OTHER".
        artist (str, optional): The artist of the piece. Defaults to "OTHER".
        instrument (str, optional): The value placed after "INST=" for the first
                                    track, e.g. "0" or "DRUMS". Defaults to "0".

    Returns:
        str: The seed string.
    """
    return f"PIECE_START GENRE={genre} ARTIST={artist} TRACK_START INST={instrument}"


def get_instruments(text_sequence: str) -> List[str]:
    """
    Collects the instrument name of every "INST=" token in a text sequence.

    Tokens are whitespace-separated. "INST=DRUMS" is reported as "Drums"; any
    other "INST=<n>" token is resolved through GM_INSTRUMENTS[int(<n>)].

    Args:
        text_sequence (str): The text sequence.

    Returns:
        List[str]: The instrument names, in order of appearance (one entry per
                   "INST=" token, duplicates included).
    """
    prefix = "INST="
    names = []
    for token in text_sequence.split():
        if not token.startswith(prefix):
            continue
        value = token[len(prefix):]
        names.append("Drums" if value == "DRUMS" else GM_INSTRUMENTS[int(value)])
    return names


def change_last_instrument(
    text_sequence: str,
    instrument: str,
    temp: float = 0.75,
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Replaces the instrument of the last track in a song string and re-renders the song.

    Only the last token containing "INST=" is rewritten; the notes are left as they are.
    The returned song string is re-joined with single spaces between tokens.

    Args:
        text_sequence (str): The song string.
        instrument (str): The instrument name, looked up in the module-level `instruments`.
        temp (float, optional): Not used by this function. Defaults to 0.75.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, new song string, and number of tokens string.
    """
    # Position 0 in `instruments` stands for the drum kit; every other
    # position i maps to instrument number i - 1.
    position = instruments.index(instrument)
    inst_value = "DRUMS" if position == 0 else str(position - 1)

    # Walk the tokens from the end and rewrite the first INST= token found.
    tokens = text_sequence.split()
    for idx in range(len(tokens) - 1, -1, -1):
        if "INST=" in tokens[idx]:
            tokens[idx] = f"INST={inst_value}"
            break
    text_sequence = " ".join(tokens)

    outputs = get_outputs_from_string(text_sequence, qpm)
    audio, midi_file, fig, instruments_str, num_tokens = outputs
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
            


def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Sampling is repeated until the newly generated part (everything after the
    seed) contains at least one "NOTE_ON"; there is no upper bound on the
    number of attempts.

    Args:
        seed (str): The seed string for the generation.
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.

    Returns:
        str: The generated instrument sequence (seed included).
    """
    # The prompt and the stop token are the same for every attempt, so they
    # are prepared once instead of on each retry.
    seed_length = len(tokenizer.encode(seed))
    input_ids = tokenizer.encode(seed, return_tensors="pt")
    # Move the input_ids tensor to the same device as the model
    input_ids = input_ids.to(model.device)
    # Generation stops at the first TRACK_END token.
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    while True:
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )
        generated_tokens = generated_ids[0]

        # Accept the attempt only if something beyond the seed contains a note;
        # the full sequence is decoded only for the attempt that is returned.
        new_generated_sequence = tokenizer.decode(generated_tokens[seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return tokenizer.decode(generated_tokens)


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Renders a token sequence into every output shown to the user.

    The MIDI file is always written to the fixed relative path "midi_ouput.mid".

    Args:
        generated_sequence (str): The generated sequence of tokens.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: The audio waveform, MIDI file name, plot figure,
                                               instruments string, and number of tokens string.
    """
    midi_path = "midi_ouput.mid"

    # One "- <name>" bullet per instrument found in the sequence.
    instrument_names = get_instruments(generated_sequence)
    instruments_str = "\n".join([f"- {name}" for name in instrument_names])

    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    # Synthesize float samples, then convert them to 16-bit integers.
    float_samples = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_samples = note_seq.audio_io.float_samples_to_int16(float_samples)

    fig = note_seq.plot_sequence(note_sequence, show_figure=False)

    # The token count is reported as a string.
    token_count = len(generated_sequence.split())
    num_tokens = str(token_count)

    audio = gr.make_waveform((SAMPLE_RATE, int16_samples))
    note_seq.note_sequence_to_midi_file(note_sequence, midi_path)
    return audio, midi_path, fig, instruments_str, num_tokens


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument from a song string and returns the various output formats.

    When removing the last track leaves no track at all (the song had one track
    or none), a new first track is generated with generate_song() instead, using
    its default instrument. The requested tempo is used in every case.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, new song string, and number of tokens string.
    """
    # We split the song into tracks by splitting on 'TRACK_START'
    tracks = text_sequence.split("TRACK_START")
    # We keep all tracks except the last one and join them back together,
    # adding back the 'TRACK_START' that was removed by split
    new_song = "TRACK_START".join(tracks[:-1])

    if len(tracks) <= 2:
        # Zero or one instrument: nothing playable is left, so start again from
        # what remains (the piece header for one track, "" for no track).
        # qpm is forwarded so the regenerated song keeps the requested tempo.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_song, qpm=qpm
        )
    else:
        audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
            new_song, qpm
        )

    return audio, midi_file, fig, instruments_str, new_song, num_tokens
    
def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last instrument in a song string and returns the various output formats.

    The last track is dropped and generated again with the same instrument. If the
    song string contains no "INST=" token (for example an empty song), a track is
    generated with generate_song()'s default instrument instead of failing.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, new song string, and number of tokens string.
    """
    # Value of the last INST= token, or None when the song has no instrument yet.
    instrument_id = next(
        (
            token.split("=")[1]
            for token in reversed(text_sequence.split())
            if "INST=" in token
        ),
        None,
    )

    # Keep everything before the last 'TRACK_START'; join adds back the
    # separators that split removed.
    new_seed = "TRACK_START".join(text_sequence.split("TRACK_START")[:-1])

    if instrument_id is None:
        # No instrument to reuse: fall back to generate_song's default one.
        outputs = generate_song(text_sequence=new_seed, qpm=qpm)
    else:
        if instrument_id == "DRUMS":
            instrument = "Drums"
        else:
            # `instruments` is shifted by one relative to the instrument number
            # in the token (position 0 is the drum kit), hence the + 1.
            instrument = instruments[int(instrument_id) + 1]
        outputs = generate_song(
            instrument=instrument, text_sequence=new_seed, qpm=qpm
        )

    audio, midi_file, fig, instruments_str, new_song, num_tokens = outputs
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Re-renders a song string at a new tempo; the song string itself is returned unchanged.

    Args:
        text_sequence (str): The song string.
        qpm (int): The new quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, text sequence, and number of tokens string.
    """
    rendered = get_outputs_from_string(text_sequence, qpm=qpm)
    audio, midi_file, fig, instruments_str, num_tokens = rendered
    # The text sequence is slotted in before the token count.
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens


def generate_song(
    genre: str = "OTHER",
    artist: str = "KATE_BUSH",
    instrument: str = "Acoustic Grand Piano",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates one new track and returns the resulting song in the various output formats.

    With an empty text sequence a new piece is started from the genre, artist and
    instrument. Otherwise a new track for the instrument is appended to the given
    sequence, and genre and artist are not used.

    Args:
        genre (str, optional): The genre of the song. Defaults to "OTHER".
        artist (str, optional): The artist style to inspire the song. Defaults to "KATE_BUSH".
        instrument (str, optional): The instrument name, looked up in the module-level `instruments`.
                                    Defaults to "Acoustic Grand Piano".
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
        text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, generated song string, and number of tokens string.
    """
    # Position 0 in `instruments` stands for the drum kit; every other
    # position i maps to instrument number i - 1.
    position = instruments.index(instrument)
    inst_value = "DRUMS" if position == 0 else str(position - 1)

    if text_sequence != "":
        # Continue the existing song with a fresh track.
        seed_string = text_sequence + " TRACK_START INST=" + inst_value
    else:
        seed_string = create_seed_string(genre, artist, inst_value)

    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    rendered = get_outputs_from_string(generated_sequence, qpm)
    audio, midi_file, fig, instruments_str, num_tokens = rendered
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens