File size: 1,853 Bytes
9e34a62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from fam.llm.adapters.base import BaseDataAdapter


class TiltedEncodec(BaseDataAdapter):
    """Adapter that splits hierarchical token output into text ids and audio ids.

    The first token hierarchy mixes text and audio ids; every later hierarchy
    is treated as audio-only. A single separator value, ``end_of_audio_token``,
    partitions ids:
    - ids greater than it are text;
    - ids smaller than it are audio;
    - the separator itself is discarded.
    """

    def __init__(self, end_of_audio_token):
        """
        Args:
            end_of_audio_token: Separator id. Audio ids are smaller than it and
                text ids are larger; ids equal to it are dropped on decode.
        """
        self._end_of_audio_token = end_of_audio_token

    def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
        """Separate text ids and per-hierarchy audio ids.

        Args:
            tokens: One sequence of token ids per hierarchy. At least two
                hierarchies are required.

        Returns:
            ``(text_ids, audio_ids)``:
            - ``text_ids`` holds the ids from the first hierarchy that are
              greater than ``end_of_audio_token``, with the last such id dropped.
            - ``audio_ids`` holds one list per hierarchy of the ids smaller than
              ``end_of_audio_token``.
            If the hierarchies yield different numbers of audio ids, a warning
            is printed and every list is truncated to the shortest length.

        Raises:
            AssertionError: If fewer than two hierarchies are supplied.
        """
        assert len(tokens) > 1, "expected at least two token hierarchies"
        eoa = self._end_of_audio_token

        # The first hierarchy is special-cased because it carries text ids as
        # well as audio ids. Walk it once, routing each id to the right bucket.
        # TODO: this may not need a special case and could be handled on its own.
        text_ids = []
        first_audio_ids = []
        for t in tokens[0]:
            if t > eoa:
                text_ids.append(t)
            elif t < eoa:
                first_audio_ids.append(t)

        # Remaining hierarchies contribute audio ids only.
        extracted_audio_ids = [first_audio_ids]
        extracted_audio_ids.extend([t for t in hierarchy if t < eoa] for hierarchy in tokens[1:])

        # Compute per-hierarchy lengths once. Unequal lengths cannot be stacked,
        # so fall back to the shortest common length.
        lengths = [len(ids) for ids in extracted_audio_ids]
        min_len = min(lengths)
        max_len = max(lengths)
        if min_len != max_len:
            print("WARNING: Number of tokens at each hierarchy must be of the same length!")
            print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
            print(lengths)
            extracted_audio_ids = [ids[:min_len] for ids in extracted_audio_ids]

        # The final text id is dropped; presumably it is an end-of-text marker.
        # NOTE(review): confirm against the tokenizer.
        return text_ids[:-1], extracted_audio_ids

    def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
        """
        Performs the required combination and padding as needed.

        Not implemented for this adapter.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError