Update app.py
app.py CHANGED
@@ -39,7 +39,7 @@ def GenerateDrums(input_midi, input_num_tokens):
     model = TransformerWrapper(
         num_tokens = PAD_IDX+1,
         max_seq_len = SEQ_LEN,
-        attn_layers = Decoder(dim = 1024, depth =
+        attn_layers = Decoder(dim = 1024, depth = 8, heads = 16, attn_flash = True)
     )
 
     model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
@@ -50,7 +50,7 @@ def GenerateDrums(input_midi, input_num_tokens):
     print('Loading model checkpoint...')
 
     model.load_state_dict(
-        torch.load('
+        torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER4_RST_VEL_8L_13501_steps_0.3341_loss_0.8893_acc.pth',
                    map_location=DEVICE))
     print('=' * 70)
 
@@ -146,7 +146,7 @@ def GenerateDrums(input_midi, input_num_tokens):
         time += (o-128)
         ncount = 0
 
-        if 256
+        if 256 < o < 384:
             ncount += 1
 
         if o > 127 and time < ntime:
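For context, a minimal sketch of what the first two hunks amount to once applied. It assumes the TransformerWrapper, Decoder and AutoregressiveWrapper classes come from the x-transformers package and that SEQ_LEN, PAD_IDX and DEVICE are defined earlier in app.py; the placeholder values below are illustrative only, not the Space's actual settings.

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# Placeholder values -- the real SEQ_LEN / PAD_IDX / DEVICE are set elsewhere in app.py.
SEQ_LEN = 8192
PAD_IDX = 512
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model definition as it reads after this commit: an 8-layer, 16-head decoder
# (dim 1024) with flash attention, wrapped for autoregressive generation.
model = TransformerWrapper(
    num_tokens = PAD_IDX + 1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 1024, depth = 8, heads = 16, attn_flash = True)
)
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)

# Checkpoint loading as it reads after this commit.
print('Loading model checkpoint...')
model.load_state_dict(
    torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER4_RST_VEL_8L_13501_steps_0.3341_loss_0.8893_acc.pth',
               map_location=DEVICE))
print('=' * 70)

The third hunk is a one-line change in the token-decoding loop: with the new condition, ncount is incremented only for token values o strictly between 256 and 384.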