YourMT3 / amt /src /tests /metrics_test.py
mimbres's picture
.
a03c9b4
raw
history blame
4.69 kB
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""metrics_test.py:
This file contains tests for the following classes:
• AMTMetrics
"""
import unittest
import warnings
import torch
import numpy as np
from utils.metrics import AMTMetrics
from utils.metrics import compute_track_metrics
class TestAMTMetrics(unittest.TestCase):
def test_individual_attributes(self):
metric = AMTMetrics()
# Test updating the metric using .update() method
metric.onset_f.update(0.5)
# Test updating the metric using __call__() method
metric.onset_f(0.5)
# Test updating the metric with a weight
metric.onset_f(0, weight=1.0)
# Test computing the average value of the metric
computed_value = metric.onset_f.compute()
self.assertAlmostEqual(computed_value, 0.3333333333333333)
# Test resetting the metric
metric.onset_f.reset()
with self.assertWarns(UserWarning):
torch._assert(metric.onset_f.compute(), torch.nan)
# Test bulk_compute
with self.assertWarns(UserWarning):
computed_metrics = metric.bulk_compute()
def test_bulk_update_and_compute(self):
metric = AMTMetrics()
# Test bulk_update with values only
d1 = {'onset_f': 0.5, 'offset_f': 0.5}
metric.bulk_update(d1)
# Test bulk_update with values and weights
d2 = {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}}
metric.bulk_update(d2)
# Test bulk_compute
computed_metrics = metric.bulk_compute()
# Ensure the 'onset_f' and 'offset_f' keys exist in the computed_metrics dictionary
self.assertIn('onset_f', computed_metrics)
self.assertIn('offset_f', computed_metrics)
# Check the computed values
self.assertAlmostEqual(computed_metrics['onset_f'], 0.5)
self.assertAlmostEqual(computed_metrics['offset_f'], 0.5)
def test_compute_track_metrics_singing(self):
from config.vocabulary import SINGING_SOLO_CLASS, GM_INSTR_CLASS_PLUS
from utils.event2note import note_event2note
ref_notes_dict = np.load('extras/examples/singing_notes.npy', allow_pickle=True).tolist()
ref_note_events_dict = np.load('extras/examples/singing_note_events.npy', allow_pickle=True).tolist()
est_notes, _ = note_event2note(ref_note_events_dict['note_events'])
ref_notes = ref_notes_dict['notes']
metric = AMTMetrics(prefix=f'test/', extra_classes=[k for k in SINGING_SOLO_CLASS.keys()])
drum_metric, non_drum_metric, instr_metric = compute_track_metrics(est_notes,
ref_notes,
eval_vocab=SINGING_SOLO_CLASS,
eval_drum_vocab=None,
onset_tolerance=0.05)
metric.bulk_update(drum_metric)
metric.bulk_update(non_drum_metric)
metric.bulk_update(instr_metric)
computed_metrics = metric.bulk_compute()
cnt = 0
for k, v in computed_metrics.items():
if 'Singing Voice' in k:
self.assertEqual(v, 1.0)
cnt += 1
self.assertEqual(cnt, 6)
metric = AMTMetrics(prefix=f'test/', extra_classes=[k for k in GM_INSTR_CLASS_PLUS.keys()])
drum_metric, non_drum_metric, instr_metric = compute_track_metrics(est_notes,
ref_notes,
eval_vocab=GM_INSTR_CLASS_PLUS,
eval_drum_vocab=None,
onset_tolerance=0.05)
metric.bulk_update(drum_metric)
metric.bulk_update(non_drum_metric)
metric.bulk_update(instr_metric)
computed_metrics = metric.bulk_compute()
cnt = 0
for k, v in computed_metrics.items():
if 'Singing Voice' in k:
self.assertEqual(v, 1.0)
cnt += 1
self.assertEqual(cnt, 6)
if __name__ == '__main__':
unittest.main()