asigalov61 commited on
Commit
8ac01e3
·
verified ·
1 Parent(s): 7618549

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -49
app.py CHANGED
@@ -62,44 +62,6 @@ def GenerateDrums(input_midi, input_num_tokens):
62
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
63
  start_time = reqtime.time()
64
 
65
- print('Loading model...')
66
-
67
- SEQ_LEN = 8192 # Models seq len
68
- PAD_IDX = 393 # Models pad index
69
- DEVICE = 'cuda' # 'cuda'
70
-
71
- # instantiate the model
72
-
73
- model = TransformerWrapper(
74
- num_tokens = PAD_IDX+1,
75
- max_seq_len = SEQ_LEN,
76
- attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
77
- )
78
-
79
- model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
80
-
81
- model.to(DEVICE)
82
- print('=' * 70)
83
-
84
- print('Loading model checkpoint...')
85
-
86
- model.load_state_dict(
87
- torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER3_VEL_11222_steps_0.5749_loss_0.8085_acc.pth',
88
- map_location=DEVICE))
89
- print('=' * 70)
90
-
91
- model.eval()
92
-
93
- if DEVICE == 'cpu':
94
- dtype = torch.bfloat16
95
- else:
96
- dtype = torch.float16
97
-
98
- ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
99
-
100
- print('Done!')
101
- print('=' * 70)
102
-
103
  fn = os.path.basename(input_midi.name)
104
  fn1 = fn.split('.')[0]
105
 
@@ -277,6 +239,44 @@ if __name__ == "__main__":
277
  opt = parser.parse_args()
278
 
279
  soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  app = gr.Blocks()
282
  with app:
@@ -310,20 +310,20 @@ if __name__ == "__main__":
310
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
311
 
312
  gr.Examples(
313
- [["Ultimate-Drums-Transformer-Melody-Seed-1.mid", 16],
314
- ["Ultimate-Drums-Transformer-Melody-Seed-2.mid", 16],
315
- ["Ultimate-Drums-Transformer-Melody-Seed-3.mid", 16],
316
- ["Ultimate-Drums-Transformer-Melody-Seed-4.mid", 16],
317
- ["Ultimate-Drums-Transformer-Melody-Seed-5.mid", 16],
318
- ["Ultimate-Drums-Transformer-Melody-Seed-6.mid", 16],
319
- ["Ultimate-Drums-Transformer-MI-Seed-1.mid", 16],
320
- ["Ultimate-Drums-Transformer-MI-Seed-2.mid", 16],
321
- ["Ultimate-Drums-Transformer-MI-Seed-3.mid", 16],
322
- ["Ultimate-Drums-Transformer-MI-Seed-4.mid", 16]],
323
  [input_midi, input_num_tokens],
324
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
325
  GenerateDrums,
326
- cache_examples=False,
327
  )
328
 
329
  app.queue().launch()
 
62
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
63
  start_time = reqtime.time()
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  fn = os.path.basename(input_midi.name)
66
  fn1 = fn.split('.')[0]
67
 
 
239
  opt = parser.parse_args()
240
 
241
  soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
242
+
243
+ print('Loading model...')
244
+
245
+ SEQ_LEN = 8192 # Models seq len
246
+ PAD_IDX = 393 # Models pad index
247
+ DEVICE = 'cuda' # 'cuda'
248
+
249
+ # instantiate the model
250
+
251
+ model = TransformerWrapper(
252
+ num_tokens = PAD_IDX+1,
253
+ max_seq_len = SEQ_LEN,
254
+ attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
255
+ )
256
+
257
+ model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
258
+
259
+ model.to(DEVICE)
260
+ print('=' * 70)
261
+
262
+ print('Loading model checkpoint...')
263
+
264
+ model.load_state_dict(
265
+ torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER3_VEL_11222_steps_0.5749_loss_0.8085_acc.pth',
266
+ map_location=DEVICE))
267
+ print('=' * 70)
268
+
269
+ model.eval()
270
+
271
+ if DEVICE == 'cpu':
272
+ dtype = torch.bfloat16
273
+ else:
274
+ dtype = torch.float16
275
+
276
+ ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
277
+
278
+ print('Done!')
279
+ print('=' * 70)
280
 
281
  app = gr.Blocks()
282
  with app:
 
310
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
311
 
312
  gr.Examples(
313
+ [["Ultimate-Drums-Transformer-Melody-Seed-1.mid", 128],
314
+ ["Ultimate-Drums-Transformer-Melody-Seed-2.mid", 128],
315
+ ["Ultimate-Drums-Transformer-Melody-Seed-3.mid", 128],
316
+ ["Ultimate-Drums-Transformer-Melody-Seed-4.mid", 128],
317
+ ["Ultimate-Drums-Transformer-Melody-Seed-5.mid", 128],
318
+ ["Ultimate-Drums-Transformer-Melody-Seed-6.mid", 128],
319
+ ["Ultimate-Drums-Transformer-MI-Seed-1.mid", 128],
320
+ ["Ultimate-Drums-Transformer-MI-Seed-2.mid", 128],
321
+ ["Ultimate-Drums-Transformer-MI-Seed-3.mid", 128],
322
+ ["Ultimate-Drums-Transformer-MI-Seed-4.mid", 128]],
323
  [input_midi, input_num_tokens],
324
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
325
  GenerateDrums,
326
+ cache_examples=True,
327
  )
328
 
329
  app.queue().launch()