# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
import numpy as np
from typing import Optional, Union, Tuple, Dict, Any, List, Counter
from utils.note_event_dataclasses import NoteEvent, Event, NoteEventListsBundle
from config.task import task_cfg
from config.config import model_cfg
from utils.tokenizer import NoteEventTokenizer
from utils.utils import create_program2channel_vocab
from utils.note2event import separate_channel_by_program_group_from_note_event_lists_bundle

SINGING_PROGRAM = 100
DRUM_PROGRAM = 128
UNANNOTATED_PROGRAM = 129

# import random
# class RandomProgramSampler:
#     def __init__(self, program_vocab: Dict[str, int], max_n: int = 7):
#         for key, values in program_vocab.items():
#             for value in values:
#                 self.inverse_vocab_program[value] = values[0]
#         self.max_n = max_n
#         self.shuffled_

#     def sample(self):

# def shuffle_and_repeat_randomly(lst, max_n=5):
#     shuffled = lst.copy()
#     random.shuffle(shuffled)
#     index = 0

#     while True:
#         if index >= len(shuffled):  # Once every element of the list has been used, reshuffle
#             random.shuffle(shuffled)
#             index = 0

#         n = random.randint(1, max_n)  # Choose a random chunk size between 1 and max_n
#         end_index = index + n

#         if end_index > len(shuffled):  # If the chunk would run past the end of the list, yield only up to the end
#             yield shuffled[index:]
#             index = len(shuffled)
#         else:
#             yield shuffled[index:end_index]
#             index = end_index


class TaskManager:
    """
    The TaskManager class manages tasks for training. It is initialized with a task name and retrieves
    the corresponding configuration from the task_cfg dictionary defined in config/task.py.

    Attributes:
        # Basic
        task_name (str): The name of the task being managed.
        base_codec (str): The base codec associated with the task.
        train_program_vocab (dict): The program vocabulary used for training.
        train_drum_vocab (dict): The drum vocabulary used for training.
        subtask_tokens (list): Additional tokens specific to subtasks, if any.
        extra_tokens (list): Extra tokens used in the task, including subtask tokens.
        ignore_decoding_tokens (list): Tokens to ignore during decoding.
        ignore_decoding_tokens_from_and_to (Optional[List[str]]): A [start, end] delimiter pair; tokens
            falling between the two delimiters are ignored during decoding. Default is None.
        tokenizer (NoteEventTokenizer): An instance of the NoteEventTokenizer class for tokenizing note events.
        eval_subtask_prefix_dict (dict): A dictionary mapping evaluation subtask names to prefix tokens.

        # Multi-channel decoding task exclusive
        num_decoding_channels (int): The number of decoding channels.
        max_note_token_length_per_ch (int): The maximum note token length per channel.
        mask_loss_strategy (str): The mask loss strategy to use. NOT IMPLEMENTED YET.
        program2channel_vocab (dict): A dictionary mapping programs to decoding channels.

    Methods:
        get_tokenizer(): Returns the tokenizer instance associated with the task.
        set_tokenizer(): Initializes the tokenizer using the NoteEventTokenizer class with the appropriate parameters.
    """

    def __init__(self, task_name: str = "mt3_full_plus", max_shift_steps: int = 206, debug_mode: bool = False):
        """
        Initializes a TaskManager object with the specified task name.

        Args:
            task_name (str): The name of the task to manage.
            max_shift_steps (int): The maximum shift steps for the tokenizer. Default is 206; configurable in config/config.py.
            debug_mode (bool): Whether to enable debug mode. Default is False.
        """
        self.debug_mode = debug_mode
        self.task_name = task_name

        if task_name not in task_cfg:
            raise ValueError(f"Invalid task name: {task_name}")
        self.task = task_cfg[task_name]

        # Basic task parameters
        self.base_codec = self.task.get("base_codec", "mt3")
        self.train_program_vocab = self.task["train_program_vocab"]
        self.train_drum_vocab = self.task["train_drum_vocab"]
        self.subtask_tokens = self.task.get("subtask_tokens", [])
        self.extra_tokens = self.subtask_tokens + self.task.get("extra_tokens", [])
        self.ignore_decoding_tokens = self.task.get("ignore_decoding_tokens", [])
        self.ignore_decoding_tokens_from_and_to = self.task.get("ignore_decoding_tokens_from_and_to", None)
        self.max_note_token_length = self.task.get("max_note_token_length", model_cfg["event_length"])
        self.max_task_token_length = self.task.get("max_task_token_length", 0)
        self.padding_task_token = self.task.get("padding_task_token", False)
        self._eval_subtask_prefix = self.task.get("eval_subtask_prefix", None)
        self.eval_subtask_prefix_dict = {}

        # Multi-channel decoding exclusive parameters
        self.num_decoding_channels = self.task.get("num_decoding_channels", 1)
        if self.num_decoding_channels > 1:
            program2channel_vocab_source = self.task.get("program2channel_vocab_source", None)
            if program2channel_vocab_source is None:
                program2channel_vocab_source = self.train_program_vocab

            # Create an inverse mapping of program to channel
            if self.num_decoding_channels == len(program2channel_vocab_source) + 1:
                self.program2channel_vocab, _ = create_program2channel_vocab(program2channel_vocab_source)
            else:
                raise ValueError("Invalid num_decoding_channels, or program2channel_vocab not provided")

            self.max_note_token_length_per_ch = self.task.get("max_note_token_length_per_ch")
            if self.max_note_token_length_per_ch is None:
                raise ValueError("max_note_token_length_per_ch is required when num_decoding_channels > 1")
            self.mask_loss_strategy = self.task.get("mask_loss_strategy", None)  # Not implemented yet
        else:
            self.max_note_token_length_per_ch = self.max_note_token_length

        # Define max_total_token_length
        self.max_total_token_length = self.max_note_token_length_per_ch + self.max_task_token_length
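        # e.g. (hypothetical values) max_note_token_length_per_ch=1024 with
        # max_task_token_length=0 gives max_total_token_length=1024 per channel.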

        # Max shift steps for the tokenizer
        self.max_shift_steps = max_shift_steps

        # Initialize a tokenizer
        self.set_tokenizer()
        self.set_eval_task_prefix()
        self.num_tokens = self.tokenizer.num_tokens
        self.inverse_vocab_program = self.tokenizer.codec.inverse_vocab_program

    def set_eval_task_prefix(self) -> None:
        """
        Sets the evaluation subtask prefix tokens for the task.

        Example:
            self.eval_subtask_prefix_dict = {
                "default": [Event("transcribe_all", 0), Event("task", 0)],
                "singing-only": [Event("transcribe_singing", 0), Event("task", 0)]
            }
        """
        if self._eval_subtask_prefix is not None:
            assert "default" in self._eval_subtask_prefix.keys()
            for key, val in self._eval_subtask_prefix.items():
                if self.padding_task_token:
                    self.eval_subtask_prefix_dict[key] = self.tokenizer.encode_task(
                        val, max_length=self.max_task_token_length)
                else:
                    self.eval_subtask_prefix_dict[key] = self.tokenizer.encode_task(val)
        else:
            self.eval_subtask_prefix_dict["default"] = []

    def get_eval_subtask_prefix_dict(self) -> dict:
        return self.eval_subtask_prefix_dict

    def get_tokenizer(self) -> NoteEventTokenizer:
        """
        Returns the tokenizer instance associated with the task.

        Returns:
            NoteEventTokenizer: The tokenizer instance.
        """
        return self.tokenizer

    def set_tokenizer(self) -> None:
        """
        Initializes the tokenizer using the NoteEventTokenizer class with the appropriate parameters.
        """
        self.tokenizer = NoteEventTokenizer(base_codec=self.base_codec,
                                            max_length=self.max_total_token_length,
                                            program_vocabulary=self.train_program_vocab,
                                            drum_vocabulary=self.train_drum_vocab,
                                            special_tokens=['PAD', 'EOS', 'UNK'],
                                            extra_tokens=self.extra_tokens,
                                            max_shift_steps=self.max_shift_steps,
                                            ignore_decoding_tokens=self.ignore_decoding_tokens,
                                            ignore_decoding_tokens_from_and_to=self.ignore_decoding_tokens_from_and_to,
                                            debug_mode=self.debug_mode)

    # Newly implemented for exclusive transcription task
    def tokenize_task_and_note_events_batch(
            self,
            programs_segments: List[List[int]],
            has_unannotated_segments: List[bool],
            note_event_segments: NoteEventListsBundle,
            subunit_programs_segments: Optional[List[List[np.ndarray]]] = None,  # TODO
            subunit_note_event_segments: Optional[List[NoteEventListsBundle]] = None,  # TODO
            stage: str = 'train'  # 'train' or 'eval'
    ):
        """Tokenizes a batch of note events into a batch of encoded tokens.
           Optionally, appends task tokens to the note event tokens.
        
        Args:
            programs_segments (List[int]): A list of program numbers.
            has_unannotated_segments (bool): Whether the batch has unannotated segments.
            note_event_segments (NoteEventListsBundle): A bundle of note events.
            subunit_programs_segments (Optional[List[List[np.ndarray]]]): A list of subunit programs.
            subunit_note_event_segments (Optional[List[NoteEventListsBundle]]): A list of subunit note events.
        
        Returns:
            np.ndarray: A batch of encoded tokens, with shape (B, C, L).        
        """
        if self.task_name == 'exclusive':
            # batch_sz = len(programs_segments)
            # token_array = np.zeros((batch_sz, self.num_decoding_channels, self.max_note_token_length_per_ch),
            #                        dtype=np.int32)

            # for programs, has_unannotated, note_events, tie_note_events, start_times in zip(
            #         programs_segments, has_unannotated_segments, note_event_segments['note_events'],
            #         note_event_segments['tie_note_events'], note_event_segments['start_times']):
            #     if has_unannotated:
            #         annotated_programs = [p for p in programs if p != UNANNOTATED_PROGRAM]
            #         note_token_array = self.tokenizer.encode_plus(note_events,
            #                                                       tie_note_events,
            #                                                       start_times,
            #                                                       pad_to_max_length=False) # will append EOS token
            #         task_token_array = self.tokenizer.encode_task(task_events)
            #     else:
            #         annotated_programs = programs

            #     task_events = [Event('transcribe_all', 0), Event('task', 0)]
            #     note_token_array = self.tokenize_note_events_batch(note_events)
            #     task_token_array = self.tokenize_task_events(annotated_programs, has_unannotated)
            # return []
            raise NotImplementedError("Exclusive transcription task is not implemented yet.")
        else:
            # Default task: single or multi-channel decoding, without appending task tokens
            return self.tokenize_note_events_batch(note_event_segments)  # (B, C, L)
            # Exclusive transcription task
            # if has_unannotated_segments:
            #     annotated_programs = [p for p in programs_segments if p != UNANNOTATED_PROGRAM]
            # else:
            #     annotated_programs = programs_segments

            # # Main task: transcribe all
            # main_task_events = self.task.get("eval_subtask_prefix")

    def tokenize_note_events_batch(self,
                                   note_event_segments: NoteEventListsBundle,
                                   start_time_to_zero: bool = False,
                                   sort: bool = True) -> np.ndarray:
        """Tokenizes a batch of note events into a batch of encoded tokens.
        
        Args:
            note_event_segments (NoteEventListsBundle): A bundle of note events.
        
        Returns:
            np.ndarray: A batch of encoded tokens, with shape (B, C, L).        
        """
        batch_sz = len(note_event_segments["note_events"])
        note_token_array = np.zeros((batch_sz, self.num_decoding_channels, self.max_note_token_length_per_ch),
                                    dtype=np.int32)

        if self.num_decoding_channels == 1:
            # Single-channel decoding task
            zipped_events = list(zip(*note_event_segments.values()))
            for b in range(batch_sz):
                note_token_array[b, 0, :] = self.tokenizer.encode_plus(*zipped_events[b],
                                                                       max_length=self.max_note_token_length,
                                                                       pad_to_max_length=True)
        elif self.num_decoding_channels > 1:
            # Multi-channel decoding task
            ch_sep_ne_bundle = separate_channel_by_program_group_from_note_event_lists_bundle(
                source_note_event_lists_bundle=note_event_segments,
                num_program_groups=self.num_decoding_channels,
                program2channel_vocab=self.program2channel_vocab,
                start_time_to_zero=start_time_to_zero,
                sort=sort)  # (batch_sz,)

            for b in range(batch_sz):
                zipped_channel = list(zip(*ch_sep_ne_bundle[b].values()))
                for c in range(self.num_decoding_channels):
                    note_token_array[b, c, :] = self.tokenizer.encode_plus(*zipped_channel[c],
                                                                           max_length=self.max_note_token_length_per_ch,
                                                                           pad_to_max_length=True)
        return note_token_array  # (B, C, L)
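
    # Shape sketch for tokenize_note_events_batch (values hypothetical): for a
    # single-channel task, C == 1 and L == max_note_token_length; for a multi-channel
    # task, C == num_decoding_channels and L == max_note_token_length_per_ch. For
    # example, num_decoding_channels=13 with max_note_token_length_per_ch=256 would
    # yield a (B, 13, 256) int32 array.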

    def tokenize_note_events(self,
                             note_events: List[NoteEvent],
                             tie_note_events: Optional[List[NoteEvent]] = None,
                             start_time: float = 0.,
                             **kwargs: Any) -> List[int]:
        """(Deprecated) Tokenizes a sequence of note events into a sequence of encoded tokens."""
        return self.tokenizer.encode_plus(note_events, tie_note_events, start_time, **kwargs)


# Deprecated: tokenize_task_events_batch below was used by datasets_eval.py and is kept for reference.

#     def tokenize_task_events_batch(self, programs_segments: List[int],
#                                    has_unannotated_segments: List[bool]) -> List[int]:
#         """Tokenizes batch of task tokens from annotation info.

#         Args:
#             programs_segments (List[int]): A list of program numbers.
#             has_unannotated_segments (bool): Whether the batch has unannotated segments.

#         Returns:
#             np.ndarray: Shape (B, C, L).

#         """
#         batch_sz = len(programs_segments)
#         task_token_array = np.zeros((batch_sz, self.num_decoding_channels, self.max_task_token_length), dtype=np.int32)

#         if self.max_task_token_length == 0:
#             return task_token_array

#         if self.num_decoding_channels == 1:
#             for b in range(batch_sz):
#                 task_token_array[b, 0, :] = self.tokenize_task_events(programs_segments[b], has_unannotated_segments[b])
#         elif self.num_decoding_channels > 1:
#             for b in range(batch_sz):
#                 task_token_array[b, :, :] = self.tokenize_task_events(programs_segments[b], has_unannotated_segments[b])
#         return task_token_array  # (B, C, L)

    def tokenize_task_events(self, programs: List[int], has_unannotated: bool) -> List[int]:
        """Tokenizes a sequence of programs into a sequence of encoded tokens. Used for training."""
        if self.task_name == 'singing_drum_v1':
            if has_unannotated:
                if SINGING_PROGRAM in programs:
                    task_events = [Event('transcribe_singing', 0), Event('task', 0)]
                elif DRUM_PROGRAM in programs:
                    task_events = [Event('transcribe_drum', 0), Event('task', 0)]
                else:
                    # Guard against an unbound `task_events`: unannotated segments of this
                    # task are expected to contain the singing or the drum program.
                    raise ValueError("Unannotated segment contains neither singing nor drum program.")
            else:
                task_events = [Event('transcribe_all', 0), Event('task', 0)]
        else:
            return []

        if self.padding_task_token:
            return self.tokenizer.encode_task(task_events, max_length=self.max_task_token_length)
        else:
            return self.tokenizer.encode_task(task_events)
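
    # Illustrative behavior of tokenize_task_events for the 'singing_drum_v1' task:
    #   tokenize_task_events([SINGING_PROGRAM], has_unannotated=True)
    #     -> encodes [Event('transcribe_singing', 0), Event('task', 0)]
    #   tokenize_task_events([0, 32], has_unannotated=False)
    #     -> encodes [Event('transcribe_all', 0), Event('task', 0)]
    # Any other task name returns an empty token list.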

    def detokenize(
        self,
        tokens: List[int],
        start_time: float = 0.,
        return_events: bool = False
    ) -> Union[Tuple[List[NoteEvent], List[NoteEvent]], Tuple[List[NoteEvent], List[NoteEvent], List[Event], int]]:
        """Decodes a sequence of tokens into note events, ignoring specific token IDs.
        Returns:
            Union[Tuple[List[NoteEvent], List[NoteEvent]],
                Tuple[List[NoteEvent], List[NoteEvent], List[Event], int]]: The decoded note events.
            If `return_events` is False, the returned tuple contains `note_events`, `tie_note_events`,
            `last_activity`, and `err_cnt`.
            If `return_events` is True, the returned tuple contains `note_events`, `tie_note_events`,
            `last_activity`, `events`, and `err_cnt`.

        Notes:
            This decoding process ignores specific token IDs based on `self.ids_to_ignore_decoding` attribute.
        """
        return self.tokenizer.decode(tokens=tokens, start_time=start_time, return_events=return_events)

    def detokenize_list_batches(
        self,
        list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]],
        list_start_times: Union[List[List[float]], List[float]],
        return_events: bool = False
    ) -> Union[Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], int, float]]], Counter[str]], Tuple[
            List[List[Tuple[List[NoteEvent], List[NoteEvent], int, float]]], List[List[Event]], Counter[str]]]:
        """ Decodes a list of variable size batches of token array to a list of
            zipped note_events and tie_note_events.

        Args:
            list_batch_tokens: List[np.ndarray], where array shape is (batch_size, variable_length)
            list_start_times: List[float], where the length is sum of all batch_sizes.
            return_events: bool
        
        Returns:
            list_list_zipped_note_events_and_tie:
                List[
                    Tuple[
                        List[NoteEvent]: A list of note events.
                        List[NoteEvent]: A list of tie note events.
                        List[Tuple[int, int]]: The last activity of each segment, as [(program, pitch), ...].
                            This is useful for validating notes within a batch of segments extracted from a file.
                        List[float]: A list of segment start times.
                    ]
                ]
            (Optional) list_events:
                List[List[Event]]
            total_err_cnt:
                Counter[str]: error counter.

        """
        return self.tokenizer.decode_list_batches(list_batch_tokens, list_start_times, return_events)
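

# A minimal smoke-test sketch (not part of the original API). It assumes the default
# "mt3_full_plus" task config is importable and feeds a single empty segment; the bundle
# keys follow the ones accessed by the methods above.
if __name__ == "__main__":
    tm = TaskManager(task_name="mt3_full_plus")
    print(f"num_tokens={tm.num_tokens}, num_decoding_channels={tm.num_decoding_channels}")

    empty_bundle = {
        "note_events": [[]],  # one segment with no note events
        "tie_note_events": [[]],  # no tie (note-continuation) events
        "start_times": [0.],  # the segment starts at t = 0
    }
    tokens = tm.tokenize_task_and_note_events_batch(
        programs_segments=[[0]],
        has_unannotated_segments=[False],
        note_event_segments=empty_bundle,
    )
    print(tokens.shape)  # expected: (1, num_decoding_channels, max_note_token_length_per_ch)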