Spaces:
Sleeping
Sleeping
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""event_codec_test.py:

This file contains tests for the following classes:
    • Event
    • EventRange
    • FastCodec, which is equivalent to the MT3 authors' Codec

See tokenizer_test.py for the FastCodec performance benchmark.
"""
import unittest | |
from utils.note_event_dataclasses import Event, EventRange | |
from utils.event_codec import FastCodec as Codec | |
# from utils.event_codec import Codec | |
class TestEvent(unittest.TestCase):
    """Sanity checks for the ``Event`` record."""

    def test_Event(self):
        # An Event must expose exactly the type and value it was built with.
        shift_event = Event(type='shift', value=0)
        expected_fields = {'type': 'shift', 'value': 0}
        for field_name, want in expected_fields.items():
            self.assertEqual(getattr(shift_event, field_name), want)
class TestEventRange(unittest.TestCase):
    """Sanity checks for the ``EventRange`` record."""

    def test_EventRange(self):
        # Build a range with keyword bounds, then read every field back.
        event_range = EventRange('abc', min_value=0, max_value=500)
        checks = (
            ('type', 'abc'),
            ('min_value', 0),
            ('max_value', 500),
        )
        for field_name, want in checks:
            self.assertEqual(getattr(event_range, field_name), want)
class TestEventCodec(unittest.TestCase):
    """Round-trip tests for ``Codec.encode_event`` / ``Codec.decode_event_index``."""

    def test_event_codec(self):
        """Each event below must decode back to itself after being encoded."""
        ec = Codec(
            special_tokens=['asd'],
            max_shift_steps=1001,
            event_ranges=[
                EventRange('pitch', min_value=0, max_value=127),
                EventRange('velocity', min_value=0, max_value=1),
                EventRange('tie', min_value=0, max_value=0),
                EventRange('program', min_value=0, max_value=127),
                EventRange('drum', min_value=0, max_value=127),
            ],
        )
        # Covers the min and max of every EventRange configured above, plus a
        # few shift values and middle C.
        events = [
            Event(type='shift', value=0),  # zero shift (not needed in practice)
            Event(type='shift', value=1),  # 10 ms shift
            Event(type='shift', value=1000),  # 10 s shift
            Event(type='pitch', value=0),  # lowest pitch 8.18 Hz
            Event(type='pitch', value=60),  # C4 or 261.63 Hz
            Event(type='pitch', value=127),  # highest pitch G9 or 12543.85 Hz
            Event(type='velocity', value=0),  # lowest velocity
            Event(type='velocity', value=1),  # highest velocity
            Event(type='tie', value=0),  # tie
            Event(type='program', value=0),  # lowest program
            Event(type='program', value=127),  # highest program
            Event(type='drum', value=0),  # lowest drum
            Event(type='drum', value=127),  # highest drum
        ]
        encoded = [ec.encode_event(e) for e in events]
        decoded = [ec.decode_event_index(idx) for idx in encoded]
        self.assertSequenceEqual(events, decoded)
class TestEventCodecErrorCases(unittest.TestCase):
    """Invalid event types, values and indices must raise ``ValueError``."""

    def setUp(self):
        # (type, min_value, max_value) for each range the codec knows about.
        range_specs = (
            ("program", 0, 127),
            ("pitch", 0, 127),
            ("velocity", 0, 3),
            ("drum", 0, 127),
            ("tie", 0, 1),
        )
        self.event_ranges = [EventRange(*spec) for spec in range_specs]
        # No special tokens, 1000 shift steps.
        self.ec = Codec([], 1000, self.event_ranges)

    def test_encode_event_with_invalid_event_type(self):
        """An event type absent from ``event_ranges`` cannot be encoded."""
        with self.assertRaises(ValueError):
            self.ec.encode_event(Event("unknown_event_type", 50))

    def test_encode_event_with_invalid_event_value(self):
        """A value above the range's ``max_value`` (127 for program) is rejected."""
        with self.assertRaises(ValueError):
            self.ec.encode_event(Event("program", 200))

    def test_event_type_range_with_invalid_event_type(self):
        """Looking up the index range of an unknown type fails."""
        self.assertRaises(ValueError, self.ec.event_type_range, "unknown_event_type")

    def test_decode_event_index_with_invalid_index(self):
        """An index far beyond the vocabulary size cannot be decoded."""
        self.assertRaises(ValueError, self.ec.decode_event_index, 1000000)
class TestEventCodecVocabulary(unittest.TestCase):
    """Tests for the optional program/drum vocabularies, which merge several raw
    values into a single encoded token."""

    @staticmethod
    def _make_codec(**vocabulary_kwargs):
        """Build the codec shared by these tests; kwargs select a vocabulary."""
        return Codec(special_tokens=['asd'],
                     max_shift_steps=1001,
                     event_ranges=[
                         EventRange('pitch', min_value=0, max_value=127),
                         EventRange('velocity', min_value=0, max_value=1),
                         EventRange('tie', min_value=0, max_value=0),
                         EventRange('program', min_value=0, max_value=127),
                         EventRange('drum', min_value=0, max_value=127),
                     ],
                     **vocabulary_kwargs)

    def test_encode_event_using_program_vocabulary(self):
        prog_vocab = {"Piano": [0, 1, 2, 3, 4, 5, 6, 7], "xxx": [50, 30, 120]}
        ec = self._make_codec(program_vocabulary=prog_vocab)
        events = [
            Event(type='program', value=0),  # 0 --> 0
            Event(type='program', value=7),  # 7 --> 0
            Event(type='program', value=111),  # 111 --> 111 (not in vocabulary)
            Event(type='program', value=30),  # 30 --> 50
        ]
        encoded = [ec.encode_event(e) for e in events]
        # Encoded index of program 0. Presumably 1 special token + 1001 shift
        # steps + 128 pitches + 2 velocities + 1 tie = 1133, i.e. the codec lays
        # out special tokens, then shift, then event_ranges in order.
        # TODO(review): confirm against the codec implementation.
        program_offset = 1133
        expected = [
            program_offset + 0,
            program_offset + 0,
            program_offset + 111,
            program_offset + 50,
        ]  # == [1133, 1133, 1244, 1183]
        self.assertSequenceEqual(encoded, expected)

    def test_encode_event_using_drum_vocabulary(self):
        drum_vocab = {"Kick": [50, 51, 52], "Snare": [53, 54]}
        ec = self._make_codec(drum_vocabulary=drum_vocab)
        events = [
            Event(type='drum', value=50),
            Event(type='drum', value=51),
            Event(type='drum', value=53),
            Event(type='drum', value=54),
        ]
        encoded = [ec.encode_event(e) for e in events]
        # Values belonging to the same vocabulary entry share one token...
        self.assertEqual(encoded[0], encoded[1])  # Kick 50 == Kick 51
        self.assertEqual(encoded[2], encoded[3])  # Snare 53 == Snare 54
        # ...while different entries must remain distinguishable.
        self.assertNotEqual(encoded[0], encoded[2])  # Kick != Snare
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()