asigalov61 commited on
Commit
9324c3c
·
verified ·
1 Parent(s): c8d9e79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -13
app.py CHANGED
@@ -23,7 +23,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
- def GenerateDrums(input_midi, input_num_tokens):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = reqtime.time()
@@ -74,6 +74,7 @@ def GenerateDrums(input_midi, input_num_tokens):
74
  print('-' * 70)
75
  print('Input file name:', fn)
76
  print('Req num toks:', input_num_tokens)
 
77
  print('-' * 70)
78
 
79
  #===============================================================================
@@ -138,6 +139,8 @@ def GenerateDrums(input_midi, input_num_tokens):
138
  with ctx:
139
  out = model.generate(x[-num_memory_tokens:],
140
  1,
 
 
141
  temperature=temperature,
142
  return_prime=False,
143
  verbose=False)
@@ -287,6 +290,7 @@ if __name__ == "__main__":
287
 
288
  input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
289
  input_num_tokens = gr.Slider(16, 2048, value=256, step=16, label="Number of composition chords to generate drums for")
 
290
 
291
  run_btn = gr.Button("generate", variant="primary")
292
 
@@ -299,21 +303,21 @@ if __name__ == "__main__":
299
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
300
 
301
 
302
- run_event = run_btn.click(GenerateDrums, [input_midi, input_num_tokens],
303
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
304
 
305
  gr.Examples(
306
- [["Ultimate-Drums-Transformer-Melody-Seed-1.mid", 128],
307
- ["Ultimate-Drums-Transformer-Melody-Seed-2.mid", 128],
308
- ["Ultimate-Drums-Transformer-Melody-Seed-3.mid", 128],
309
- ["Ultimate-Drums-Transformer-Melody-Seed-4.mid", 128],
310
- ["Ultimate-Drums-Transformer-Melody-Seed-5.mid", 128],
311
- ["Ultimate-Drums-Transformer-Melody-Seed-6.mid", 128],
312
- ["Ultimate-Drums-Transformer-MI-Seed-1.mid", 128],
313
- ["Ultimate-Drums-Transformer-MI-Seed-2.mid", 128],
314
- ["Ultimate-Drums-Transformer-MI-Seed-3.mid", 128],
315
- ["Ultimate-Drums-Transformer-MI-Seed-4.mid", 128]],
316
- [input_midi, input_num_tokens],
317
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
318
  GenerateDrums,
319
  cache_examples=True,
 
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
+ def GenerateDrums(input_midi, input_num_tokens, input_top_k_value):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = reqtime.time()
 
74
  print('-' * 70)
75
  print('Input file name:', fn)
76
  print('Req num toks:', input_num_tokens)
77
+ print('Req top_k value:', input_top_k_value)
78
  print('-' * 70)
79
 
80
  #===============================================================================
 
139
  with ctx:
140
  out = model.generate(x[-num_memory_tokens:],
141
  1,
142
+ filter_logits_fn=top_k,
143
+ filter_kwargs={'k': input_top_k_value},
144
  temperature=temperature,
145
  return_prime=False,
146
  verbose=False)
 
290
 
291
  input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
292
  input_num_tokens = gr.Slider(16, 2048, value=256, step=16, label="Number of composition chords to generate drums for")
293
+ input_top_k_value = gr.Slider(1, 50, value=1, step=1, label="Model sampling top_k value")
294
 
295
  run_btn = gr.Button("generate", variant="primary")
296
 
 
303
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
304
 
305
 
306
+ run_event = run_btn.click(GenerateDrums, [input_midi, input_num_tokens, input_top_k_value],
307
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
308
 
309
  gr.Examples(
310
+ [["Ultimate-Drums-Transformer-Melody-Seed-1.mid", 128, 1],
311
+ ["Ultimate-Drums-Transformer-Melody-Seed-2.mid", 128, 1],
312
+ ["Ultimate-Drums-Transformer-Melody-Seed-3.mid", 128, 1],
313
+ ["Ultimate-Drums-Transformer-Melody-Seed-4.mid", 128, 1],
314
+ ["Ultimate-Drums-Transformer-Melody-Seed-5.mid", 128, 1],
315
+ ["Ultimate-Drums-Transformer-Melody-Seed-6.mid", 128, 1],
316
+ ["Ultimate-Drums-Transformer-MI-Seed-1.mid", 128, 1],
317
+ ["Ultimate-Drums-Transformer-MI-Seed-2.mid", 128, 1],
318
+ ["Ultimate-Drums-Transformer-MI-Seed-3.mid", 128, 1],
319
+ ["Ultimate-Drums-Transformer-MI-Seed-4.mid", 128, 1]],
320
+ [input_midi, input_num_tokens, input_top_k_value],
321
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
322
  GenerateDrums,
323
  cache_examples=True,