asigalov61 commited on
Commit
763c372
·
verified ·
1 Parent(s): 8e1aa79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -400,18 +400,18 @@ def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, inpu
400
  print('=' * 70)
401
  print('Giant Music Transformer MIDI Comparator')
402
  print('=' * 70)
403
-
404
- input_src_tokens = src_tokens
405
- input_trg_tokens = trg_tokens
406
-
407
  sampling_resolution = max(40, min(1000, input_sampling_resolution)) * 3
408
  sampling_overlap = max(0, min(500, input_sampling_overlap)) * 3
409
 
410
  comp_length = (min(len(input_src_tokens), len(input_trg_tokens)) // sampling_resolution) * sampling_resolution
 
 
 
411
 
412
  comp_cos_sims = []
413
 
414
- for i in range(0, comp_length-sampling_resolution-sampling_overlap), max(1, sampling_resolution-sampling_overlap)):
415
 
416
  torch.cuda.empty_cache()
417
 
@@ -437,7 +437,6 @@ def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, inpu
437
  cache = out[2]
438
  trg_embedings = cache.layer_hiddens[-1]
439
 
440
-
441
  cos_sim = pairwise.cosine_similarity([src_embedings.cpu().detach().numpy()[0].flatten()],
442
  [trg_embedings.cpu().detach().numpy()[0].flatten()]
443
  ).tolist()[0][0]
 
400
  print('=' * 70)
401
  print('Giant Music Transformer MIDI Comparator')
402
  print('=' * 70)
403
+
 
 
 
404
  sampling_resolution = max(40, min(1000, input_sampling_resolution)) * 3
405
  sampling_overlap = max(0, min(500, input_sampling_overlap)) * 3
406
 
407
  comp_length = (min(len(input_src_tokens), len(input_trg_tokens)) // sampling_resolution) * sampling_resolution
408
+
409
+ input_src_tokens = src_tokens[:comp_length]
410
+ input_trg_tokens = trg_tokens[:comp_length]
411
 
412
  comp_cos_sims = []
413
 
414
+ for i in range(0, comp_length, max(1, sampling_resolution-sampling_overlap)):
415
 
416
  torch.cuda.empty_cache()
417
 
 
437
  cache = out[2]
438
  trg_embedings = cache.layer_hiddens[-1]
439
 
 
440
  cos_sim = pairwise.cosine_similarity([src_embedings.cpu().detach().numpy()[0].flatten()],
441
  [trg_embedings.cpu().detach().numpy()[0].flatten()]
442
  ).tolist()[0][0]