Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -62,44 +62,6 @@ def GenerateDrums(input_midi, input_num_tokens):
|
|
62 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
63 |
start_time = reqtime.time()
|
64 |
|
65 |
-
print('Loading model...')
|
66 |
-
|
67 |
-
SEQ_LEN = 8192 # Models seq len
|
68 |
-
PAD_IDX = 393 # Models pad index
|
69 |
-
DEVICE = 'cuda' # 'cuda'
|
70 |
-
|
71 |
-
# instantiate the model
|
72 |
-
|
73 |
-
model = TransformerWrapper(
|
74 |
-
num_tokens = PAD_IDX+1,
|
75 |
-
max_seq_len = SEQ_LEN,
|
76 |
-
attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
|
77 |
-
)
|
78 |
-
|
79 |
-
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
80 |
-
|
81 |
-
model.to(DEVICE)
|
82 |
-
print('=' * 70)
|
83 |
-
|
84 |
-
print('Loading model checkpoint...')
|
85 |
-
|
86 |
-
model.load_state_dict(
|
87 |
-
torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER3_VEL_11222_steps_0.5749_loss_0.8085_acc.pth',
|
88 |
-
map_location=DEVICE))
|
89 |
-
print('=' * 70)
|
90 |
-
|
91 |
-
model.eval()
|
92 |
-
|
93 |
-
if DEVICE == 'cpu':
|
94 |
-
dtype = torch.bfloat16
|
95 |
-
else:
|
96 |
-
dtype = torch.float16
|
97 |
-
|
98 |
-
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
99 |
-
|
100 |
-
print('Done!')
|
101 |
-
print('=' * 70)
|
102 |
-
|
103 |
fn = os.path.basename(input_midi.name)
|
104 |
fn1 = fn.split('.')[0]
|
105 |
|
@@ -277,6 +239,44 @@ if __name__ == "__main__":
|
|
277 |
opt = parser.parse_args()
|
278 |
|
279 |
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
|
281 |
app = gr.Blocks()
|
282 |
with app:
|
@@ -310,20 +310,20 @@ if __name__ == "__main__":
|
|
310 |
[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
|
311 |
|
312 |
gr.Examples(
|
313 |
-
[["Ultimate-Drums-Transformer-Melody-Seed-1.mid",
|
314 |
-
["Ultimate-Drums-Transformer-Melody-Seed-2.mid",
|
315 |
-
["Ultimate-Drums-Transformer-Melody-Seed-3.mid",
|
316 |
-
["Ultimate-Drums-Transformer-Melody-Seed-4.mid",
|
317 |
-
["Ultimate-Drums-Transformer-Melody-Seed-5.mid",
|
318 |
-
["Ultimate-Drums-Transformer-Melody-Seed-6.mid",
|
319 |
-
["Ultimate-Drums-Transformer-MI-Seed-1.mid",
|
320 |
-
["Ultimate-Drums-Transformer-MI-Seed-2.mid",
|
321 |
-
["Ultimate-Drums-Transformer-MI-Seed-3.mid",
|
322 |
-
["Ultimate-Drums-Transformer-MI-Seed-4.mid",
|
323 |
[input_midi, input_num_tokens],
|
324 |
[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
|
325 |
GenerateDrums,
|
326 |
-
cache_examples=
|
327 |
)
|
328 |
|
329 |
app.queue().launch()
|
|
|
62 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
63 |
start_time = reqtime.time()
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
fn = os.path.basename(input_midi.name)
|
66 |
fn1 = fn.split('.')[0]
|
67 |
|
|
|
239 |
opt = parser.parse_args()
|
240 |
|
241 |
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
|
242 |
+
|
243 |
+
print('Loading model...')
|
244 |
+
|
245 |
+
SEQ_LEN = 8192 # Models seq len
|
246 |
+
PAD_IDX = 393 # Models pad index
|
247 |
+
DEVICE = 'cuda' # 'cuda'
|
248 |
+
|
249 |
+
# instantiate the model
|
250 |
+
|
251 |
+
model = TransformerWrapper(
|
252 |
+
num_tokens = PAD_IDX+1,
|
253 |
+
max_seq_len = SEQ_LEN,
|
254 |
+
attn_layers = Decoder(dim = 1024, depth = 4, heads = 8, attn_flash = True)
|
255 |
+
)
|
256 |
+
|
257 |
+
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
258 |
+
|
259 |
+
model.to(DEVICE)
|
260 |
+
print('=' * 70)
|
261 |
+
|
262 |
+
print('Loading model checkpoint...')
|
263 |
+
|
264 |
+
model.load_state_dict(
|
265 |
+
torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER3_VEL_11222_steps_0.5749_loss_0.8085_acc.pth',
|
266 |
+
map_location=DEVICE))
|
267 |
+
print('=' * 70)
|
268 |
+
|
269 |
+
model.eval()
|
270 |
+
|
271 |
+
if DEVICE == 'cpu':
|
272 |
+
dtype = torch.bfloat16
|
273 |
+
else:
|
274 |
+
dtype = torch.float16
|
275 |
+
|
276 |
+
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
277 |
+
|
278 |
+
print('Done!')
|
279 |
+
print('=' * 70)
|
280 |
|
281 |
app = gr.Blocks()
|
282 |
with app:
|
|
|
310 |
[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
|
311 |
|
312 |
gr.Examples(
|
313 |
+
[["Ultimate-Drums-Transformer-Melody-Seed-1.mid", 128],
|
314 |
+
["Ultimate-Drums-Transformer-Melody-Seed-2.mid", 128],
|
315 |
+
["Ultimate-Drums-Transformer-Melody-Seed-3.mid", 128],
|
316 |
+
["Ultimate-Drums-Transformer-Melody-Seed-4.mid", 128],
|
317 |
+
["Ultimate-Drums-Transformer-Melody-Seed-5.mid", 128],
|
318 |
+
["Ultimate-Drums-Transformer-Melody-Seed-6.mid", 128],
|
319 |
+
["Ultimate-Drums-Transformer-MI-Seed-1.mid", 128],
|
320 |
+
["Ultimate-Drums-Transformer-MI-Seed-2.mid", 128],
|
321 |
+
["Ultimate-Drums-Transformer-MI-Seed-3.mid", 128],
|
322 |
+
["Ultimate-Drums-Transformer-MI-Seed-4.mid", 128]],
|
323 |
[input_midi, input_num_tokens],
|
324 |
[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
|
325 |
GenerateDrums,
|
326 |
+
cache_examples=True,
|
327 |
)
|
328 |
|
329 |
app.queue().launch()
|