"""Gradio app that generates accompaniment for Google Magenta Groove drum tracks.

A random drum track is picked from a pre-processed pickle of the dataset, a
transformer model generates additional tokens after every drum chord, and the
combined token stream is decoded to notes, written to a MIDI file, rendered to
audio and plotted.
"""

import os.path

import time as reqtime
import datetime
from pytz import timezone

import torch

import spaces
import gradio as gr

from x_transformer_1_23_2 import *
import random
import copy
import tqdm

from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX

import matplotlib.pyplot as plt

in_space = os.getenv("SYSTEM") == "spaces"

# Defined at module level (not only under the __main__ guard) because
# GenerateGroove reads both names as globals.
PDT = timezone('US/Pacific')
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

MODEL_SEQ_LEN = 4096  # Models seq len
MODEL_PAD_IDX = 1664  # Models pad index
MODEL_DEVICE = 'cuda'
MODEL_CHECKPOINT = 'Groove_Music_Transformer_Medium_Trained_Model_23268_steps_0.7459_loss_0.797_acc.pth'
DATASET_PICKLE = 'Google_Magenta_Groove_8675_Select_Processed_MIDIs'

# Patch list handed to the MIDI writer, one entry per MIDI channel.
MIDI_PATCHES = [0, 10, 19, 24, 35, 40, 52, 56, 65, 9, 73, 0, 0, 0, 0, 0]

# =================================================================================================


def _load_model():
    """Build the model, load the trained checkpoint and return ``(model, ctx)``.

    ``ctx`` is the autocast context manager to wrap generation calls in.
    """
    print('Loading model...')

    # instantiate the model
    model = TransformerWrapper(
        num_tokens=MODEL_PAD_IDX + 1,
        max_seq_len=MODEL_SEQ_LEN,
        attn_layers=Decoder(dim=1024, depth=24, heads=16, attn_flash=True),
    )
    model = AutoregressiveWrapper(model, ignore_index=MODEL_PAD_IDX)
    model.to(MODEL_DEVICE)
    print('=' * 70)

    print('Loading model checkpoint...')
    model.load_state_dict(torch.load(MODEL_CHECKPOINT, map_location=MODEL_DEVICE))
    print('=' * 70)

    model.eval()

    dtype = torch.bfloat16 if MODEL_DEVICE == 'cpu' else torch.float16
    ctx = torch.amp.autocast(device_type=MODEL_DEVICE, dtype=dtype)

    print('Done!')
    print('=' * 70)

    return model, ctx


def _generate_tokens(model, ctx, drums_score,
                     num_prime_chords=7,
                     max_notes_per_chord=8,
                     num_samples=4,
                     num_memory_tokens=4096,
                     temperature=1.0):
    """Interleave the drum chords with model-generated tokens.

    The first ``num_prime_chords`` chords are copied verbatim as a prime. For
    each later chord, the chord is appended and the model is then sampled one
    token at a time until it yields a token <= 255 (or no sample in the batch
    is > 255), or until ``max_notes_per_chord`` tokens > 383 were accepted.

    Returns the flat list of integer tokens.
    """
    tokens = []
    for chord in drums_score[:num_prime_chords]:
        tokens.extend(chord)

    print('Generating...')

    for step, chord in enumerate(drums_score[num_prime_chords:]):
        tokens.extend(chord)

        # A fixed token is injected once, right after the first non-prime chord.
        if step == 0:
            tokens.append(256 + 12)

        # The slice is already a fresh list, so it can be extended in place.
        seq = tokens[-num_memory_tokens:]
        prime_len = len(seq)

        batch_value = 256
        notes_count = 0

        while batch_value > 255 and notes_count < max_notes_per_chord:
            x = torch.tensor([seq] * num_samples, dtype=torch.long, device=MODEL_DEVICE)

            with ctx:
                out = model.generate(x,
                                     1,
                                     temperature=temperature,
                                     return_prime=False,
                                     verbose=False)

            # Keep only samples above the 0..255 range; -1 ends the inner loop.
            candidates = [o[0] for o in out.tolist() if o[0] > 255] or [-1]
            batch_value = random.choice(candidates)

            if batch_value > 255:
                seq.append(batch_value)
                if batch_value > 383:
                    notes_count += 1

        tokens.extend(seq[prime_len:])

    return tokens


def _tokens_to_notes(tokens):
    """Decode a token stream into a list of note events.

    Token ranges, as handled by the code below:
      0..127    advance the running start time by ``token * 32``
      128..255  emit a fixed-duration note on channel 9 (drums)
      257..383  set the duration used for following notes
      385..1663 emit a note; channel and pitch are packed as ``chan*128+pitch``

    Returns a list of ``['note', start, duration, channel, pitch, velocity,
    patch]`` entries (empty when ``tokens`` is empty).
    """
    notes = []

    start = 0
    dur = 32
    drum_vels = [100, 120]

    for tok in tokens:
        if 0 <= tok < 128:
            start += tok * 32

        if 128 <= tok < 256:
            drum_pitch = tok - 128
            notes.append(['note', start, 32, 9, drum_pitch, drum_vels[drum_pitch % 2], 128])

        # NOTE(review): the lower bounds are exclusive, so tokens 256 and 384
        # are ignored - kept exactly as in the original; confirm this is intended.
        if 256 < tok < 384:
            dur = (tok - 256) * 32

        if 384 < tok < 1664:
            chan, pitch = divmod(tok - 384, 128)

            # NOTE(review): with tok < 1664, chan is at most 9, so the
            # ``chan == 11`` case looks unreachable here - mapping kept as-is.
            if chan == 11:
                channel = 9
            elif chan > 8:
                channel = chan + 1
            else:
                channel = chan

            patch = 128 if channel == 9 else channel * 8
            vel = max(50, pitch)

            notes.append(['note', start, dur, channel, pitch, vel, patch])

    return notes


@spaces.GPU
def GenerateGroove():
    """Generate music for a randomly chosen Groove drum track.

    Loads the model and the processed drum scores, generates tokens, writes
    ``<score name>.mid`` via TMIDIX and renders it to audio.

    Returns:
        Tuple ``(title, summary, midi_path, (sample_rate, audio), plot)``
        matching the Gradio output components wired up in the UI.
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    model, ctx = _load_model()

    print('Loading Google Magenta Groove processed MIDIs...')
    all_scores = TMIDIX.Tegridy_Any_Pickle_File_Reader(DATASET_PICKLE)
    print('Done!')
    print('=' * 70)

    print('=' * 70)

    # randrange excludes the upper bound; randint(0, len(...)) could pick an
    # index one past the end of the list and raise IndexError.
    drums_score_idx = random.randrange(len(all_scores))
    drums_score_fn = all_scores[drums_score_idx][0]
    drums_score = all_scores[drums_score_idx][1][:160]

    print('Drums score index', drums_score_idx)
    print('Drums score name', drums_score_fn)
    print('Drums score length', len(drums_score))
    print('=' * 70)

    #==================================================================

    print('=' * 70)
    print('Sample input events', drums_score[:5])
    print('=' * 70)
    print('Prepping drums track...')

    outy = _generate_tokens(model, ctx, drums_score)

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #===============================================================================

    print('Rendering results...')

    print('=' * 70)
    print('Sample INTs', outy[:12])
    print('=' * 70)

    song_f = _tokens_to_notes(outy)

    fn1 = drums_score_fn

    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                             output_signature='Groove Music Transformer',
                                             output_file_name=fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=MIDI_PATCHES
                                             )

    new_fn = fn1 + '.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_title = str(fn1)
    output_midi_summary = str(song_f[:3])
    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', '')
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot

# =================================================================================================

if __name__ == "__main__":

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    app = gr.Blocks()
    with app:
        # The two headings are written with explicit newline escapes: a plain
        # double-quoted literal cannot span several physical lines.
        gr.Markdown("\n\nGroove Music Transformer\n\n")
        gr.Markdown("\n\nGenerate music for Google Magenta Groove MIDI dataset drums tracks\n\n")
        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Groove-Music-Transformer&style=flat)\n\n"
            "Generate music for Google Magenta Groove MIDI dataset drums tracks\n\n"
            "Based upon [Google Magenta Groove MIDI Dataset](https://magenta.tensorflow.org/datasets/groove)\n\n"
        )

        run_btn = gr.Button("generate groove", variant="primary")

        gr.Markdown("## Generation results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        # The outputs list must stay in the same order as GenerateGroove's
        # return tuple.
        run_event = run_btn.click(GenerateGroove,
                                  outputs=[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])

    app.queue().launch()