asigalov61 commited on
Commit
c9d9ce3
·
verified ·
1 Parent(s): de9ee66

Upload 2 files

Browse files
Melody2Song_Seq2Seq_Music_Transformer.ipynb ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "VGrGd6__l5ch"
7
+ },
8
+ "source": [
9
+ "# Melody2Song Seq2Seq Music Transformer (ver. 1.0)\n",
10
+ "\n",
11
+ "***\n",
12
+ "\n",
13
+ "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
14
+ "\n",
15
+ "***\n",
16
+ "\n",
17
+ "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/\n",
18
+ "\n",
19
+ "***\n",
20
+ "\n",
21
+ "#### Project Los Angeles\n",
22
+ "\n",
23
+ "#### Tegridy Code 2024\n",
24
+ "\n",
25
+ "***"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {
31
+ "id": "shLrgoXdl5cj"
32
+ },
33
+ "source": [
34
+ "# (GPU CHECK)"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {
41
+ "id": "X3rABEpKCO02",
42
+ "cellView": "form"
43
+ },
44
+ "outputs": [],
45
+ "source": [
46
+ "# @title NVIDIA GPU Check\n",
47
+ "!nvidia-smi"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "metadata": {
53
+ "id": "0RcVC4btl5ck"
54
+ },
55
+ "source": [
56
+ "# (SETUP ENVIRONMENT)"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {
63
+ "id": "viHgEaNACPTs",
64
+ "cellView": "form"
65
+ },
66
+ "outputs": [],
67
+ "source": [
68
+ "# @title Install requirements\n",
69
+ "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools\n",
70
+ "!pip install einops\n",
71
+ "!pip install torch-summary\n",
72
+ "!apt install fluidsynth"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {
79
+ "id": "DzCOZU_gBiQV",
80
+ "cellView": "form"
81
+ },
82
+ "outputs": [],
83
+ "source": [
84
+ "# @title Load all needed modules\n",
85
+ "\n",
86
+ "print('=' * 70)\n",
87
+ "print('Loading needed modules...')\n",
88
+ "print('=' * 70)\n",
89
+ "\n",
90
+ "import os\n",
91
+ "import pickle\n",
92
+ "import random\n",
93
+ "import secrets\n",
94
+ "import tqdm\n",
95
+ "import math\n",
96
+ "import torch\n",
97
+ "\n",
98
+ "import matplotlib.pyplot as plt\n",
99
+ "\n",
100
+ "from torchsummary import summary\n",
101
+ "\n",
102
+ "%cd /content/tegridy-tools/tegridy-tools/\n",
103
+ "\n",
104
+ "import TMIDIX\n",
105
+ "from midi_to_colab_audio import midi_to_colab_audio\n",
106
+ "\n",
107
+ "%cd /content/tegridy-tools/tegridy-tools/X-Transformer\n",
108
+ "\n",
109
+ "from x_transformer_1_23_2 import *\n",
110
+ "\n",
111
+ "%cd /content/\n",
112
+ "\n",
113
+ "import random\n",
114
+ "\n",
115
+ "from sklearn import metrics\n",
116
+ "\n",
117
+ "from IPython.display import Audio, display\n",
118
+ "\n",
119
+ "from huggingface_hub import hf_hub_download\n",
120
+ "\n",
121
+ "from google.colab import files\n",
122
+ "\n",
123
+ "print('=' * 70)\n",
124
+ "print('Done')\n",
125
+ "print('=' * 70)\n",
126
+ "print('Torch version:', torch.__version__)\n",
127
+ "print('=' * 70)\n",
128
+ "print('Enjoy! :)')\n",
129
+ "print('=' * 70)"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "source": [
135
+ "# (SETUP DATA AND MODEL)"
136
+ ],
137
+ "metadata": {
138
+ "id": "SQ1_7P4bLdtB"
139
+ }
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "source": [
144
+ "#@title Load Melody2Song Seq2Seq Music Transformer Data and Pre-Trained Model\n",
145
+ "\n",
146
+ "#@markdown Model precision option\n",
147
+ "\n",
148
+ "model_precision = \"bfloat16\" # @param [\"bfloat16\", \"float16\"]\n",
149
+ "\n",
150
+ "plot_tokens_embeddings = True # @param {type:\"boolean\"}\n",
151
+ "\n",
152
+ "print('=' * 70)\n",
153
+ "print('Downloading Melody2Song Seq2Seq Music Transformer Data File...')\n",
154
+ "print('=' * 70)\n",
155
+ "\n",
156
+ "data_path = '/content'\n",
157
+ "\n",
158
+ "if os.path.isfile(data_path+'/Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data.pickle'):\n",
159
+ " print('Data file already exists...')\n",
160
+ "\n",
161
+ "else:\n",
162
+ " hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',\n",
163
+ " repo_type='space',\n",
164
+ " filename='Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data.pickle',\n",
165
+ " local_dir=data_path,\n",
166
+ " )\n",
167
+ "\n",
168
+ "print('=' * 70)\n",
169
+ "seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data')\n",
170
+ "\n",
171
+ "print('=' * 70)\n",
172
+ "print('Loading Melody2Song Seq2Seq Music Transformer Pre-Trained Model...')\n",
173
+ "print('Please wait...')\n",
174
+ "print('=' * 70)\n",
175
+ "\n",
176
+ "full_path_to_models_dir = \"/content\"\n",
177
+ "\n",
178
+ "model_checkpoint_file_name = 'Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth'\n",
179
+ "model_path = full_path_to_models_dir+'/'+model_checkpoint_file_name\n",
180
+ "num_layers = 24\n",
181
+ "if os.path.isfile(model_path):\n",
182
+ " print('Model already exists...')\n",
183
+ "\n",
184
+ "else:\n",
185
+ " hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',\n",
186
+ " repo_type='space',\n",
187
+ " filename=model_checkpoint_file_name,\n",
188
+ " local_dir=full_path_to_models_dir,\n",
189
+ " )\n",
190
+ "\n",
191
+ "\n",
192
+ "print('=' * 70)\n",
193
+ "print('Instantiating model...')\n",
194
+ "\n",
195
+ "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
196
+ "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
197
+ "device_type = 'cuda'\n",
198
+ "\n",
199
+ "if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported():\n",
200
+ " dtype = 'bfloat16'\n",
201
+ "else:\n",
202
+ " dtype = 'float16'\n",
203
+ "\n",
204
+ "if model_precision == 'float16':\n",
205
+ " dtype = 'float16'\n",
206
+ "\n",
207
+ "ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n",
208
+ "ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)\n",
209
+ "\n",
210
+ "SEQ_LEN = 2560\n",
211
+ "PAD_IDX = 514\n",
212
+ "\n",
213
+ "# instantiate the model\n",
214
+ "\n",
215
+ "model = TransformerWrapper(\n",
216
+ " num_tokens = PAD_IDX+1,\n",
217
+ " max_seq_len = SEQ_LEN,\n",
218
+ " attn_layers = Decoder(dim = 1024, depth = num_layers, heads = 16, attn_flash = True)\n",
219
+ ")\n",
220
+ "\n",
221
+ "model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)\n",
222
+ "\n",
223
+ "model.cuda()\n",
224
+ "print('=' * 70)\n",
225
+ "\n",
226
+ "print('Loading model checkpoint...')\n",
227
+ "\n",
228
+ "model.load_state_dict(torch.load(model_path))\n",
229
+ "print('=' * 70)\n",
230
+ "\n",
231
+ "model.eval()\n",
232
+ "\n",
233
+ "print('Done!')\n",
234
+ "print('=' * 70)\n",
235
+ "\n",
236
+ "print('Model will use', dtype, 'precision...')\n",
237
+ "print('=' * 70)\n",
238
+ "\n",
239
+ "# Model stats\n",
240
+ "print('Model summary...')\n",
241
+ "summary(model)\n",
242
+ "\n",
243
+ "if plot_tokens_embeddings:\n",
244
+ "\n",
245
+ " tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()\n",
246
+ "\n",
247
+ " cos_sim = metrics.pairwise_distances(\n",
248
+ " tok_emb, metric='cosine'\n",
249
+ " )\n",
250
+ " plt.figure(figsize=(7, 7))\n",
251
+ " plt.imshow(cos_sim, cmap=\"inferno\", interpolation=\"nearest\")\n",
252
+ " im_ratio = cos_sim.shape[0] / cos_sim.shape[1]\n",
253
+ " plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)\n",
254
+ " plt.xlabel(\"Position\")\n",
255
+ " plt.ylabel(\"Position\")\n",
256
+ " plt.tight_layout()\n",
257
+ " plt.plot()\n",
258
+ " plt.savefig(\"/content/Melody2Song-Seq2Seq-Music-Transformer-Tokens-Embeddings-Plot.png\", bbox_inches=\"tight\")"
259
+ ],
260
+ "metadata": {
261
+ "cellView": "form",
262
+ "id": "z7QLJ6FajxPA"
263
+ },
264
+ "execution_count": null,
265
+ "outputs": []
266
+ },
267
+ {
268
+ "cell_type": "markdown",
269
+ "source": [
270
+ "# (LOAD SEED MELODY)"
271
+ ],
272
+ "metadata": {
273
+ "id": "NdJ1_A8gNoV3"
274
+ }
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {
280
+ "id": "AIvb6MmSO9R3",
281
+ "cellView": "form"
282
+ },
283
+ "outputs": [],
284
+ "source": [
285
+ "# @title Load desired seed melody\n",
286
+ "\n",
287
+ "#@markdown NOTE: If custom MIDI file is not provided, sample seed melody will be used instead\n",
288
+ "\n",
289
+ "full_path_to_custom_seed_melody_MIDI_file = \"/content/tegridy-tools/tegridy-tools/seed-melody.mid\" # @param {type:\"string\"}\n",
290
+ "sample_seed_melody_number = 0 # @param {type:\"slider\", min:0, max:203664, step:1}\n",
291
+ "\n",
292
+ "print('=' * 70)\n",
293
+ "print('Loading seed melody...')\n",
294
+ "print('=' * 70)\n",
295
+ "\n",
296
+ "if full_path_to_custom_seed_melody_MIDI_file != '':\n",
297
+ "\n",
298
+ " #===============================================================================\n",
299
+ " # Raw single-track ms score\n",
300
+ "\n",
301
+ " raw_score = TMIDIX.midi2single_track_ms_score(full_path_to_custom_seed_melody_MIDI_file)\n",
302
+ "\n",
303
+ " #===============================================================================\n",
304
+ " # Enhanced score notes\n",
305
+ "\n",
306
+ " escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]\n",
307
+ "\n",
308
+ " #===============================================================================\n",
309
+ " # Augmented enhanced score notes\n",
310
+ "\n",
311
+ " escore_notes = TMIDIX.recalculate_score_timings(TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32))\n",
312
+ "\n",
313
+ " cscore = TMIDIX.chordify_score([1000, escore_notes])\n",
314
+ "\n",
315
+ " fixed_mel_score = TMIDIX.fix_monophonic_score_durations([c[0] for c in cscore])\n",
316
+ "\n",
317
+ " melody = []\n",
318
+ "\n",
319
+ " pe = fixed_mel_score[0]\n",
320
+ "\n",
321
+ " for s in fixed_mel_score:\n",
322
+ "\n",
323
+ " dtime = max(0, min(127, s[1]-pe[1]))\n",
324
+ " dur = max(1, min(127, s[2]))\n",
325
+ " ptc = max(1, min(127, s[4]))\n",
326
+ "\n",
327
+ " chan = 1\n",
328
+ "\n",
329
+ " melody.extend([dtime, dur+128, (128 * chan)+ptc+256])\n",
330
+ "\n",
331
+ " pe = s\n",
332
+ "\n",
333
+ " if len(melody) >= 192:\n",
334
+ " melody = [512] + melody[:192] + [513]\n",
335
+ "\n",
336
+ " else:\n",
337
+ " mult = math.ceil(192 / len(melody))\n",
338
+ " melody = melody * mult\n",
339
+ " melody = [512] + melody[:192] + [513]\n",
340
+ "\n",
341
+ " print('Loaded custom MIDI melody:', full_path_to_custom_seed_melody_MIDI_file)\n",
342
+ " print('=' * 70)\n",
343
+ "\n",
344
+ "else:\n",
345
+ " melody = seed_melodies_data[sample_seed_melody_number]\n",
346
+ " print('Loaded sample seed melody #', sample_seed_melody_number)\n",
347
+ " print('=' * 70)\n",
348
+ "\n",
349
+ "print('Sample melody INTs:', melody[:10])\n",
350
+ "print('=' * 70)\n",
351
+ "print('Done!')\n",
352
+ "print('=' * 70)"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "markdown",
357
+ "metadata": {
358
+ "id": "feXay_Ed7mG5"
359
+ },
360
+ "source": [
361
+ "# (GENERATE)"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": null,
367
+ "metadata": {
368
+ "id": "naf65RxUXwDg",
369
+ "cellView": "form"
370
+ },
371
+ "outputs": [],
372
+ "source": [
373
+ "# @title Generate song from melody\n",
374
+ "\n",
375
+ "melody_MIDI_patch_number = 40 # @param {type:\"slider\", min:0, max:127, step:1}\n",
376
+ "accompaniment_MIDI_patch_number = 0 # @param {type:\"slider\", min:0, max:127, step:1}\n",
377
+ "number_of_tokens_to_generate = 900 # @param {type:\"slider\", min:15, max:2354, step:3}\n",
378
+ "number_of_batches_to_generate = 4 # @param {type:\"slider\", min:1, max:16, step:1}\n",
379
+ "top_k_value = 25 # @param {type:\"slider\", min:1, max:50, step:1}\n",
380
+ "temperature = 0.9 # @param {type:\"slider\", min:0.1, max:1, step:0.05}\n",
381
+ "render_MIDI_to_audio = True # @param {type:\"boolean\"}\n",
382
+ "\n",
383
+ "print('=' * 70)\n",
384
+ "print('Melody2Song Seq2Seq Music Transformer Model Generator')\n",
385
+ "print('=' * 70)\n",
386
+ "\n",
387
+ "print('Generating...')\n",
388
+ "print('=' * 70)\n",
389
+ "\n",
390
+ "model.eval()\n",
391
+ "\n",
392
+ "torch.cuda.empty_cache()\n",
393
+ "\n",
394
+ "x = (torch.tensor([melody] * number_of_batches_to_generate, dtype=torch.long, device='cuda'))\n",
395
+ "\n",
396
+ "with ctx:\n",
397
+ " out = model.generate(x,\n",
398
+ " number_of_tokens_to_generate,\n",
399
+ " filter_logits_fn=top_k,\n",
400
+ " filter_kwargs={'k': top_k_value},\n",
401
+ " temperature=0.9,\n",
402
+ " return_prime=False,\n",
403
+ " verbose=True)\n",
404
+ "\n",
405
+ "output = out.tolist()\n",
406
+ "\n",
407
+ "print('=' * 70)\n",
408
+ "print('Done!')\n",
409
+ "print('=' * 70)\n",
410
+ "\n",
411
+ "#======================================================================\n",
412
+ "print('Rendering results...')\n",
413
+ "\n",
414
+ "for i in range(number_of_batches_to_generate):\n",
415
+ "\n",
416
+ " print('=' * 70)\n",
417
+ " print('Batch #', i)\n",
418
+ " print('=' * 70)\n",
419
+ "\n",
420
+ " out1 = output[i]\n",
421
+ "\n",
422
+ " print('Sample INTs', out1[:12])\n",
423
+ " print('=' * 70)\n",
424
+ "\n",
425
+ " if len(out1) != 0:\n",
426
+ "\n",
427
+ " song = out1\n",
428
+ " song_f = []\n",
429
+ "\n",
430
+ " time = 0\n",
431
+ " dur = 0\n",
432
+ " vel = 90\n",
433
+ " pitch = 0\n",
434
+ " channel = 0\n",
435
+ "\n",
436
+ " patches = [0] * 16\n",
437
+ " patches[0] = accompaniment_MIDI_patch_number\n",
438
+ " patches[3] = melody_MIDI_patch_number\n",
439
+ "\n",
440
+ " for ss in song:\n",
441
+ "\n",
442
+ " if 0 < ss < 128:\n",
443
+ "\n",
444
+ " time += (ss * 32)\n",
445
+ "\n",
446
+ " if 128 < ss < 256:\n",
447
+ "\n",
448
+ " dur = (ss-128) * 32\n",
449
+ "\n",
450
+ " if 256 < ss < 512:\n",
451
+ "\n",
452
+ " pitch = (ss-256) % 128\n",
453
+ "\n",
454
+ " channel = (ss-256) // 128\n",
455
+ "\n",
456
+ " if channel == 1:\n",
457
+ " channel = 3\n",
458
+ " vel = 110 + (pitch % 12)\n",
459
+ " song_f.append(['note', time, dur, channel, pitch, vel, melody_MIDI_patch_number])\n",
460
+ "\n",
461
+ " else:\n",
462
+ " vel = 80 + (pitch % 12)\n",
463
+ " channel = 0\n",
464
+ " song_f.append(['note', time, dur, channel, pitch, vel, accompaniment_MIDI_patch_number])\n",
465
+ "\n",
466
+ " detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n",
467
+ " output_signature = 'Melody2Song Seq2Seq Music Transformer',\n",
468
+ " output_file_name = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i),\n",
469
+ " track_name='Project Los Angeles',\n",
470
+ " list_of_MIDI_patches=patches\n",
471
+ " )\n",
472
+ " print('=' * 70)\n",
473
+ " print('Displaying resulting composition...')\n",
474
+ " print('=' * 70)\n",
475
+ "\n",
476
+ " fname = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i)\n",
477
+ "\n",
478
+ " if render_MIDI_to_audio:\n",
479
+ " midi_audio = midi_to_colab_audio(fname + '.mid')\n",
480
+ " display(Audio(midi_audio, rate=16000, normalize=False))\n",
481
+ "\n",
482
+ " TMIDIX.plot_ms_SONG(song_f, plot_title=fname)"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "markdown",
487
+ "metadata": {
488
+ "id": "z87TlDTVl5cp"
489
+ },
490
+ "source": [
491
+ "# Congrats! You did it! :)"
492
+ ]
493
+ }
494
+ ],
495
+ "metadata": {
496
+ "accelerator": "GPU",
497
+ "colab": {
498
+ "gpuClass": "premium",
499
+ "gpuType": "L4",
500
+ "private_outputs": true,
501
+ "provenance": [],
502
+ "machine_shape": "hm"
503
+ },
504
+ "kernelspec": {
505
+ "display_name": "Python 3",
506
+ "name": "python3"
507
+ },
508
+ "language_info": {
509
+ "codemirror_mode": {
510
+ "name": "ipython",
511
+ "version": 3
512
+ },
513
+ "file_extension": ".py",
514
+ "mimetype": "text/x-python",
515
+ "name": "python",
516
+ "nbconvert_exporter": "python",
517
+ "pygments_lexer": "ipython3",
518
+ "version": "3.10.12"
519
+ }
520
+ },
521
+ "nbformat": 4,
522
+ "nbformat_minor": 0
523
+ }
melody2song_seq2seq_music_transformer.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Melody2Song_Seq2Seq_Music_Transformer.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1La3iHCib9tluuv4AfsIHCwt1zu0wzl8B
8
+
9
+ # Melody2Song Seq2Seq Music Transformer (ver. 1.0)
10
+
11
+ ***
12
+
13
+ Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools
14
+
15
+ ***
16
+
17
+ WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/
18
+
19
+ ***
20
+
21
+ #### Project Los Angeles
22
+
23
+ #### Tegridy Code 2024
24
+
25
+ ***
26
+
27
+ # (GPU CHECK)
28
+ """
29
+
30
+ # @title NVIDIA GPU Check
31
+ !nvidia-smi
32
+
33
+ """# (SETUP ENVIRONMENT)"""
34
+
35
+ # @title Install requirements
36
+ !git clone --depth 1 https://github.com/asigalov61/tegridy-tools
37
+ !pip install einops
38
+ !pip install torch-summary
39
+ !apt install fluidsynth
40
+
41
+ # Commented out IPython magic to ensure Python compatibility.
42
+ # @title Load all needed modules
43
+
44
+ print('=' * 70)
45
+ print('Loading needed modules...')
46
+ print('=' * 70)
47
+
48
+ import os
49
+ import pickle
50
+ import random
51
+ import secrets
52
+ import tqdm
53
+ import math
54
+ import torch
55
+
56
+ import matplotlib.pyplot as plt
57
+
58
+ from torchsummary import summary
59
+
60
+ # %cd /content/tegridy-tools/tegridy-tools/
61
+
62
+ import TMIDIX
63
+ from midi_to_colab_audio import midi_to_colab_audio
64
+
65
+ # %cd /content/tegridy-tools/tegridy-tools/X-Transformer
66
+
67
+ from x_transformer_1_23_2 import *
68
+
69
+ # %cd /content/
70
+
71
+ import random
72
+
73
+ from sklearn import metrics
74
+
75
+ from IPython.display import Audio, display
76
+
77
+ from huggingface_hub import hf_hub_download
78
+
79
+ from google.colab import files
80
+
81
+ print('=' * 70)
82
+ print('Done')
83
+ print('=' * 70)
84
+ print('Torch version:', torch.__version__)
85
+ print('=' * 70)
86
+ print('Enjoy! :)')
87
+ print('=' * 70)
88
+
89
+ """# (SETUP DATA AND MODEL)"""
90
+
91
#@title Load Melody2Song Seq2Seq Music Transformer Data and Pre-Trained Model

#@markdown Model precision option

model_precision = "bfloat16" # @param ["bfloat16", "float16"]

plot_tokens_embeddings = True # @param {type:"boolean"}

print('=' * 70)
print('Downloading Melody2Song Seq2Seq Music Transformer Data File...')
print('=' * 70)

data_path = '/content'

# Single source of truth for the seed-melodies file: the existence check,
# the download and the reader below must all agree on the same location.
seed_melodies_file_name = 'Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data'
seed_melodies_file_path = os.path.join(data_path, seed_melodies_file_name)

if os.path.isfile(seed_melodies_file_path + '.pickle'):
    print('Data file already exists...')

else:
    # Fetch the pickle from the project's Hugging Face Space into data_path.
    hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',
                    repo_type='space',
                    filename=seed_melodies_file_name + '.pickle',
                    local_dir=data_path,
                    )

print('=' * 70)
# Read from the directory the file was checked for / downloaded into rather
# than from whatever the current working directory happens to be (this export
# has the notebook's `%cd /content` commented out).
# Assumes the reader appends '.pickle' itself, as the extension-less name
# passed by the original call implies -- TODO confirm against TMIDIX.
# NOTE(review): this presumably unpickles downloaded data; only point it at a
# trusted repository.
seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader(seed_melodies_file_path)
117
+
118
# Announce the checkpoint stage (four lines, identical to separate prints).
print('=' * 70,
      'Loading Melody2Song Seq2Seq Music Transformer Pre-Trained Model...',
      'Please wait...',
      '=' * 70,
      sep='\n')

full_path_to_models_dir = "/content"

model_checkpoint_file_name = 'Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth'
model_path = f'{full_path_to_models_dir}/{model_checkpoint_file_name}'

# Decoder depth used when the model is instantiated further down.
num_layers = 24

# Download the checkpoint from the project's Hugging Face Space only when it
# is not already present in the models directory.
if not os.path.isfile(model_path):
    hf_hub_download(repo_id='asigalov61/Melody2Song-Seq2Seq-Music-Transformer',
                    repo_type='space',
                    filename=model_checkpoint_file_name,
                    local_dir=full_path_to_models_dir,
                    )

else:
    print('Model already exists...')
137
+
138
+
139
print('=' * 70)
print('Instantiating model...')

# Allow TF32 kernels for matmul and cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda'

# float16 is the default; bfloat16 is used only when it was requested in the
# form AND the current GPU reports support for it.
dtype = 'float16'

if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'

# Map the precision name onto the torch dtype and build the autocast context
# that generation runs under.
ptdtype = torch.bfloat16 if dtype == 'bfloat16' else torch.float16
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
156
+
157
SEQ_LEN = 2560
# Padding token id; the vocabulary size passed to the model is PAD_IDX + 1.
PAD_IDX = 514

# Build the network: decoder stack -> token/position wrapper -> autoregressive
# sampling wrapper that ignores and emits PAD_IDX as padding.
model = AutoregressiveWrapper(
    TransformerWrapper(
        num_tokens=PAD_IDX + 1,
        max_seq_len=SEQ_LEN,
        attn_layers=Decoder(dim=1024, depth=num_layers, heads=16, attn_flash=True),
    ),
    ignore_index=PAD_IDX,
    pad_value=PAD_IDX,
)

model.cuda()
print('=' * 70)

print('Loading model checkpoint...')

# NOTE(review): torch.load deserializes the checkpoint file; only load
# checkpoints obtained from a trusted source.
model.load_state_dict(torch.load(model_path))
print('=' * 70)

# Inference only from here on.
model.eval()

print('Done!')
print('=' * 70)

print(f'Model will use {dtype} precision...')
print('=' * 70)

# Model stats
print('Model summary...')
summary(model)
189
+
190
# Optional diagnostic: heat-map of pairwise cosine distances between the
# rows of the model's token-embedding matrix, saved as a PNG under /content.
if plot_tokens_embeddings:

    # Pull the embedding weights off the GPU as plain nested lists.
    tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()

    cos_sim = metrics.pairwise_distances(tok_emb, metric='cosine')

    plt.figure(figsize=(7, 7))
    plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")

    # Scale the colorbar with the image aspect ratio.
    n_rows, n_cols = cos_sim.shape
    im_ratio = n_rows / n_cols
    plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)

    # Both axes carry the same label.
    for set_axis_label in (plt.xlabel, plt.ylabel):
        set_axis_label("Position")

    plt.tight_layout()
    plt.plot()
    plt.savefig("/content/Melody2Song-Seq2Seq-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")
206
+
207
+ """# (LOAD SEED MELODY)"""
208
+
209
# @title Load desired seed melody

#@markdown NOTE: If custom MIDI file is not provided, sample seed melody will be used instead

full_path_to_custom_seed_melody_MIDI_file = "/content/tegridy-tools/tegridy-tools/seed-melody.mid" # @param {type:"string"}
sample_seed_melody_number = 0 # @param {type:"slider", min:0, max:203664, step:1}

print('=' * 70)
print('Loading seed melody...')
print('=' * 70)

# A non-empty path selects the custom MIDI file; an empty string falls back
# to one of the bundled sample seed melodies.
if full_path_to_custom_seed_melody_MIDI_file:

    #===============================================================================
    # Raw single-track ms score

    raw_score = TMIDIX.midi2single_track_ms_score(full_path_to_custom_seed_melody_MIDI_file)

    #===============================================================================
    # Enhanced score notes

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    #===============================================================================
    # Augmented enhanced score notes

    # timings_divider=32 pairs with the *32 scaling applied when generated
    # tokens are turned back into notes.
    escore_notes = TMIDIX.recalculate_score_timings(TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32))

    cscore = TMIDIX.chordify_score([1000, escore_notes])

    # Keep only the first note of every chord, then fix the durations of the
    # resulting single-note line.
    fixed_mel_score = TMIDIX.fix_monophonic_score_durations([c[0] for c in cscore])

    # Fail with a clear message instead of an IndexError / ZeroDivisionError
    # further down when the file yields no notes at all.
    if not fixed_mel_score:
        raise ValueError('No notes found in custom MIDI melody: ' + full_path_to_custom_seed_melody_MIDI_file)

    melody = []

    pe = fixed_mel_score[0]

    for s in fixed_mel_score:

        # Three tokens per note (indices 1/2/4 presumably start time,
        # duration and pitch -- TODO confirm against TMIDIX):
        #   delta start time      -> 0..127
        #   duration + 128        -> 129..255
        #   pitch + 128*chan + 256 -> 385..511 for chan == 1
        dtime = max(0, min(127, s[1]-pe[1]))
        dur = max(1, min(127, s[2]))
        ptc = max(1, min(127, s[4]))

        chan = 1

        melody.extend([dtime, dur+128, (128 * chan)+ptc+256])

        pe = s

    # The prompt body is always exactly 192 tokens: short melodies are tiled
    # (mult > 1), long ones are left as-is (mult == 1) and then truncated.
    # Tokens 512 and 513 bracket the melody.
    mult = math.ceil(192 / len(melody))
    melody = [512] + (melody * mult)[:192] + [513]

    print('Loaded custom MIDI melody:', full_path_to_custom_seed_melody_MIDI_file)
    print('=' * 70)

else:
    melody = seed_melodies_data[sample_seed_melody_number]
    print('Loaded sample seed melody #', sample_seed_melody_number)
    print('=' * 70)

print('Sample melody INTs:', melody[:10])
print('=' * 70)
print('Done!')
print('=' * 70)
277
+
278
+ """# (GENERATE)"""
279
+
280
# @title Generate song from melody

melody_MIDI_patch_number = 40 # @param {type:"slider", min:0, max:127, step:1}
accompaniment_MIDI_patch_number = 0 # @param {type:"slider", min:0, max:127, step:1}
number_of_tokens_to_generate = 900 # @param {type:"slider", min:15, max:2354, step:3}
number_of_batches_to_generate = 4 # @param {type:"slider", min:1, max:16, step:1}
top_k_value = 25 # @param {type:"slider", min:1, max:50, step:1}
temperature = 0.9 # @param {type:"slider", min:0.1, max:1, step:0.05}
render_MIDI_to_audio = True # @param {type:"boolean"}

print('=' * 70)
print('Melody2Song Seq2Seq Music Transformer Model Generator')
print('=' * 70)

print('Generating...')
print('=' * 70)

model.eval()

torch.cuda.empty_cache()

# One identical copy of the seed melody per requested batch row.
x = (torch.tensor([melody] * number_of_batches_to_generate, dtype=torch.long, device='cuda'))

# Sample under the autocast context set up with the model.
with ctx:
    out = model.generate(x,
                         number_of_tokens_to_generate,
                         filter_logits_fn=top_k,
                         filter_kwargs={'k': top_k_value},
                         # Use the form value; this was hard-coded to 0.9, so
                         # the temperature slider had no effect.
                         temperature=temperature,
                         # Return only the newly generated tokens.
                         return_prime=False,
                         verbose=True)

# Plain nested lists, one token list per batch row.
output = out.tolist()

print('=' * 70)
print('Done!')
print('=' * 70)
317
+
318
#======================================================================
# Decode every generated token list into notes, write it out as a MIDI file,
# optionally render audio, and plot the result.
print('Rendering results...')

for i in range(number_of_batches_to_generate):

    print('=' * 70)
    print('Batch #', i)
    print('=' * 70)

    out1 = output[i]

    print('Sample INTs', out1[:12])
    print('=' * 70)

    # Nothing to render for an empty batch row.
    if not out1:
        continue

    song = out1
    song_f = []

    # Running decoder state; a note is emitted on each pitch token using the
    # most recent time and duration values.
    time = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0

    # Accompaniment plays on MIDI channel 0, the melody on channel 3.
    patches = [0] * 16
    patches[0] = accompaniment_MIDI_patch_number
    patches[3] = melody_MIDI_patch_number

    for ss in song:

        if 0 < ss < 128:
            # Delta start time token.
            time += (ss * 32)

        elif 128 < ss < 256:
            # Duration token.
            dur = (ss-128) * 32

        elif 256 < ss < 512:
            # Pitch token: the quotient selects the voice, the remainder is
            # the pitch.
            channel, pitch = divmod(ss-256, 128)

            if channel == 1:
                channel = 3
                vel = 110 + (pitch % 12)
                patch = melody_MIDI_patch_number

            else:
                channel = 0
                vel = 80 + (pitch % 12)
                patch = accompaniment_MIDI_patch_number

            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    # Output path without extension; '.mid' is appended when the rendered
    # file is read back below.
    fname = '/content/Melody2Song-Seq2Seq-Music-Transformer-Composition_'+str(i)

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'Melody2Song Seq2Seq Music Transformer',
                                                              output_file_name = fname,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches
                                                              )
    print('=' * 70)
    print('Displaying resulting composition...')
    print('=' * 70)

    if render_MIDI_to_audio:
        midi_audio = midi_to_colab_audio(fname + '.mid')
        display(Audio(midi_audio, rate=16000, normalize=False))

    TMIDIX.plot_ms_SONG(song_f, plot_title=fname)
390
+
391
+ """# Congrats! You did it! :)"""