Update app.py
Browse files
app.py
CHANGED
@@ -400,18 +400,18 @@ def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, inpu
|
|
400 |
print('=' * 70)
|
401 |
print('Giant Music Transformer MIDI Comparator')
|
402 |
print('=' * 70)
|
403 |
-
|
404 |
-
input_src_tokens = src_tokens
|
405 |
-
input_trg_tokens = trg_tokens
|
406 |
-
|
407 |
sampling_resolution = max(40, min(1000, input_sampling_resolution)) * 3
|
408 |
sampling_overlap = max(0, min(500, input_sampling_overlap)) * 3
|
409 |
|
410 |
comp_length = (min(len(input_src_tokens), len(input_trg_tokens)) // sampling_resolution) * sampling_resolution
|
|
|
|
|
|
|
411 |
|
412 |
comp_cos_sims = []
|
413 |
|
414 |
-
for i in range(0, comp_length
|
415 |
|
416 |
torch.cuda.empty_cache()
|
417 |
|
@@ -437,7 +437,6 @@ def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, inpu
|
|
437 |
cache = out[2]
|
438 |
trg_embedings = cache.layer_hiddens[-1]
|
439 |
|
440 |
-
|
441 |
cos_sim = pairwise.cosine_similarity([src_embedings.cpu().detach().numpy()[0].flatten()],
|
442 |
[trg_embedings.cpu().detach().numpy()[0].flatten()]
|
443 |
).tolist()[0][0]
|
|
|
400 |
print('=' * 70)
|
401 |
print('Giant Music Transformer MIDI Comparator')
|
402 |
print('=' * 70)
|
403 |
+
|
|
|
|
|
|
|
404 |
sampling_resolution = max(40, min(1000, input_sampling_resolution)) * 3
|
405 |
sampling_overlap = max(0, min(500, input_sampling_overlap)) * 3
|
406 |
|
407 |
comp_length = (min(len(input_src_tokens), len(input_trg_tokens)) // sampling_resolution) * sampling_resolution
|
408 |
+
|
409 |
+
input_src_tokens = src_tokens[:comp_length]
|
410 |
+
input_trg_tokens = trg_tokens[:comp_length]
|
411 |
|
412 |
comp_cos_sims = []
|
413 |
|
414 |
+
for i in range(0, comp_length, max(1, sampling_resolution-sampling_overlap)):
|
415 |
|
416 |
torch.cuda.empty_cache()
|
417 |
|
|
|
437 |
cache = out[2]
|
438 |
trg_embedings = cache.layer_hiddens[-1]
|
439 |
|
|
|
440 |
cos_sim = pairwise.cosine_similarity([src_embedings.cpu().detach().numpy()[0].flatten()],
|
441 |
[trg_embedings.cpu().detach().numpy()[0].flatten()]
|
442 |
).tolist()[0][0]
|