asigalov61 commited on
Commit
5b70622
·
verified ·
1 Parent(s): 6521243

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -39,7 +39,7 @@ def GenerateDrums(input_midi, input_num_tokens):
39
  model = TransformerWrapper(
40
  num_tokens = PAD_IDX+1,
41
  max_seq_len = SEQ_LEN,
42
- attn_layers = Decoder(dim = 1024, depth = 4, heads = 16, attn_flash = True)
43
  )
44
 
45
  model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
@@ -50,7 +50,7 @@ def GenerateDrums(input_midi, input_num_tokens):
50
  print('Loading model checkpoint...')
51
 
52
  model.load_state_dict(
53
- torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER4_RST_VEL_4L_16534_steps_0.4074_loss_0.8631_acc.pth',
54
  map_location=DEVICE))
55
  print('=' * 70)
56
 
@@ -146,7 +146,7 @@ def GenerateDrums(input_midi, input_num_tokens):
146
  time += (o-128)
147
  ncount = 0
148
 
149
- if 256 <= o < 384:
150
  ncount += 1
151
 
152
  if o > 127 and time < ntime:
 
39
  model = TransformerWrapper(
40
  num_tokens = PAD_IDX+1,
41
  max_seq_len = SEQ_LEN,
42
+ attn_layers = Decoder(dim = 1024, depth = 8, heads = 16, attn_flash = True)
43
  )
44
 
45
  model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
 
50
  print('Loading model checkpoint...')
51
 
52
  model.load_state_dict(
53
+ torch.load('Ultimate_Drums_Transformer_Small_Trained_Model_VER4_RST_VEL_8L_13501_steps_0.3341_loss_0.8893_acc.pth',
54
  map_location=DEVICE))
55
  print('=' * 70)
56
 
 
146
  time += (o-128)
147
  ncount = 0
148
 
149
+ if 256 < o < 384:
150
  ncount += 1
151
 
152
  if o > 127 and time < ntime: