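# Groove Music Transformer Gradio app (Hugging Face Space)
# Generates multi-instrument accompaniment for Google Magenta Groove MIDI dataset drum tracks.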
import os.path

import time as reqtime
import datetime
from pytz import timezone

import copy
import random

import torch

import spaces
import gradio as gr

from x_transformer_1_23_2 import *

import tqdm

from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX

import matplotlib.pyplot as plt

in_space = os.getenv("SYSTEM") == "spaces"
# =================================================================================================
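# Main generation routine. Runs on a GPU worker via the @spaces.GPU decorator:
# loads the model, picks a random drum track from the dataset, generates
# accompaniment for it, and renders MIDI, audio and a score plot for the UI.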
@spaces.GPU
def GenerateGroove():

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    print('Loading model...')

    SEQ_LEN = 4096 # Model's seq len
    PAD_IDX = 1664 # Model's pad index
    DEVICE = 'cuda' # 'cuda'

    # instantiate the model
    model = TransformerWrapper(
        num_tokens = PAD_IDX+1,
        max_seq_len = SEQ_LEN,
        attn_layers = Decoder(dim = 1024, depth = 24, heads = 16, attn_flash = True)
        )

    model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)

    model.to(DEVICE)

    print('=' * 70)
    print('Loading model checkpoint...')

    model.load_state_dict(
        torch.load('Groove_Music_Transformer_Medium_Trained_Model_23268_steps_0.7459_loss_0.797_acc.pth',
                   map_location=DEVICE))

    print('=' * 70)

    model.eval()

    # bfloat16 autocast on CPU, float16 on GPU
    if DEVICE == 'cpu':
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

    print('Done!')
    print('=' * 70)

    print('Loading Google Magenta Groove processed MIDIs...')
    all_scores = TMIDIX.Tegridy_Any_Pickle_File_Reader('Google_Magenta_Groove_8675_Select_Processed_MIDIs')
    print('Done!')
    print('=' * 70)

    # pick a random pre-processed drums score (randint is inclusive, hence len-1)
    drums_score_idx = random.randint(0, len(all_scores)-1)

    drums_score_fn = all_scores[drums_score_idx][0]
    drums_score = all_scores[drums_score_idx][1]

    print('Drums score index', drums_score_idx)
    print('Drums score name', drums_score_fn)
    print('Drums score length', len(drums_score))
    print('=' * 70)

    #==================================================================
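    # The first num_prime_chords chords of the drum track are used verbatim as the
    # model prime; generation then continues chord by chord over the rest of the track.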
    print('=' * 70)
    print('Sample input events', drums_score[:5])
    print('=' * 70)

    print('Prepping drums track...')

    num_prime_chords = 7

    outy = []

    for d in drums_score[:num_prime_chords]:
        outy.extend(d)

    print('Generating...')

    max_notes_per_chord = 8
    num_samples = 4
    num_memory_tokens = 4096
    temperature = 1.0
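    # For every remaining drum chord: append its drum tokens, then sample accompaniment
    # tokens one at a time in a batch of num_samples candidates, keeping a random
    # non-time candidate per step. Time-delta tokens (values < 256) end the chord;
    # duration tokens (256-383) and instrument note tokens (384+) extend it, up to
    # max_notes_per_chord notes per chord.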
    for i in range(num_prime_chords, len(drums_score)):

        outy.extend(drums_score[i])

        if i == num_prime_chords:
            outy.append(256+12)

        input_seq = outy[-num_memory_tokens:]

        seq = copy.deepcopy(input_seq)

        batch_value = 256
        nc = 0

        while batch_value > 255 and nc < max_notes_per_chord:

            x = torch.tensor([seq] * num_samples, dtype=torch.long, device=DEVICE)

            with ctx:
                out = model.generate(x,
                                     1,
                                     temperature=temperature,
                                     return_prime=False,
                                     verbose=False)

            out1 = [o[0] for o in out.tolist() if o[0] > 255]

            if not out1:
                out1 = [-1]

            batch_value = random.choice(out1)

            if batch_value > 255:
                seq.append(batch_value)

            if batch_value > 383:
                nc += 1

        out = seq[len(input_seq):]

        outy.extend(out)

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #===============================================================================

    print('Rendering results...')

    print('=' * 70)
    print('Sample INTs', outy[:12])
    print('=' * 70)
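    # Decode the token stream into MIDI events. Token ranges:
    #   0-127    : delta time, in 32 ms steps
    #   128-255  : drum pitches (MIDI channel 9), velocity picked by pitch parity
    #   256-383  : note durations, in 32 ms steps
    #   384-1663 : instrument notes, encoded as 384 + (instrument * 128) + pitch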
    if len(outy) != 0:

        song = outy
        song_f = []

        time = 0
        dur = 32
        vel = 90
        dvels = [100, 120]
        pitch = 60
        channel = 0

        patches = [0, 10, 19, 24, 35, 40, 52, 56, 65, 9, 73, 0, 0, 0, 0, 0]

        for ss in song:

            if 0 <= ss < 128:
                time += ss * 32

            if 128 <= ss < 256:
                song_f.append(['note', time, 32, 9, ss-128, dvels[(ss-128) % 2], 128])

            if 256 < ss < 384:
                dur = (ss-256) * 32

            if 384 < ss < 1664:

                chan = (ss-384) // 128

                # remap instrument index to MIDI channel, skipping drums channel 9
                if chan == 11:
                    channel = 9
                else:
                    if chan > 8:
                        channel = chan + 1
                    else:
                        channel = chan

                if channel == 9:
                    patch = 128
                else:
                    patch = channel * 8

                pitch = (ss-384) % 128

                vel = max(50, pitch)

                song_f.append(['note', time, dur, channel, pitch, vel, patch])
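    # Render the decoded score: write a MIDI file with TMIDIX, then synthesize
    # 16 kHz audio from it with the bundled SGM soundfont for the Gradio player.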
    fn1 = drums_score_fn

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'Groove Music Transformer',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_title = str(fn1)
    output_midi_summary = str(song_f[:3])
    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
# =================================================================================================
if __name__ == "__main__":

    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    app = gr.Blocks()

    with app:

        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Groove Music Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Generate music for Google Magenta Groove MIDI dataset drums tracks</h1>")

        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Groove-Music-Transformer&style=flat)\n\n"
            "Generate music for Google Magenta Groove MIDI dataset drums tracks\n\n"
            "Based upon [Google Magenta Groove MIDI Dataset](https://magenta.tensorflow.org/datasets/groove)\n\n"
        )

        run_btn = gr.Button("generate groove", variant="primary")

        gr.Markdown("## Generation results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = run_btn.click(GenerateGroove,
                                  outputs=[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])

    app.queue().launch()