import re
import torch
import torch.nn.functional as F
import numpy as np
from transformers import MimiModel, GenerationConfig
from transformers import Pipeline, LogitsProcessor

class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
    """Constrains each generation step to the id range of the codebook that
    position belongs to, cycling through the codebooks, while always keeping
    the stop token reachable."""

    def __init__(self, input_start_len: int, codebook_size: int, num_codebooks: int, offset: int, stop_token: int):
        self.input_start_len = input_start_len
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks
        self.offset = offset
        self.stop_token = stop_token
    
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        curr_len = input_ids.shape[-1]
        # Which codebook the current step belongs to, cycling 0..num_codebooks-1
        codebook_idx = (curr_len - self.input_start_len) % self.num_codebooks

        # Mask every id outside the current codebook's range...
        scores_processed = scores.clone()
        scores_processed[:, : self.offset + codebook_idx * self.codebook_size] = -float("inf")
        scores_processed[:, self.offset + (codebook_idx + 1) * self.codebook_size :] = -float("inf")
        # ...but restore the stop token so generation can terminate
        scores_processed[:, self.stop_token] = scores[:, self.stop_token]

        return scores_processed
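
# Quick sanity sketch of the masking (illustrative, not executed; the stop id
# 50256 is an assumption, the other values mirror the pipeline's defaults):
#
#   proc = AlternatingCodebooksLogitsProcessor(
#       input_start_len=4, codebook_size=2048, num_codebooks=8,
#       offset=50257, stop_token=50256)
#   scores = proc(torch.zeros((1, 4), dtype=torch.long),
#                 torch.zeros((1, 50257 + 8 * 2048)))
#   # At position 4, (4 - 4) % 8 = 0 -> codebook 0: only ids in
#   # [50257, 50257 + 2048) and the stop token stay finite.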

class IndriTTSPipeline(Pipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Mimi neural codec, used to turn generated audio tokens back into a waveform
        self.audio_tokenizer = MimiModel.from_pretrained('kyutai/mimi').to(device=self.device)

        # TODO: Ideally all of this should come from model config
        self.convert_token = self.tokenizer.encode('[convert]')
        self.stop_token = self.tokenizer.encode('[stop]')
        self.text_modality_token = self.tokenizer.encode('[text]')
        self.acoustic_modality_token = self.tokenizer.encode('[mimi]')
        self.num_codebooks = 8
        self.audio_offset = 50257  # audio ids start right after the 50257-token text vocabulary

        self.model.stop_token = self.stop_token

        self.model.generation_config = GenerationConfig(
            eos_token_id=self.stop_token,
            max_length=kwargs.get('max_length', 1024),
            temperature=kwargs.get('temperature', 0.5),
            top_k=kwargs.get('top_k', 15),
            do_sample=kwargs.get('do_sample', True)
        )

    def _sanitize_parameters(self, **kwargs):
        speaker = kwargs.get('speaker', '[spkr_unk]')

        preprocess_kwargs = {
            'speaker': speaker
        }

        return preprocess_kwargs, {}, {}

    def _prepare_tts_tokens(self, text_tokens, speaker):
        input_tokens = np.hstack([
            self.text_modality_token,
            text_tokens,
            self.convert_token,
            self.acoustic_modality_token,
            self.tokenizer.encode(speaker)
        ])

        return input_tokens.tolist()
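
    # The resulting prompt reads:
    #   [text] <text tokens> [convert] [mimi] [spkr_*]
    # i.e. the text to speak, followed by a request to convert it into Mimi
    # audio tokens in the given speaker's voice.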

    def _sanitize_text(self, text):
        text = text.lower()
        text = re.sub(r'\n+', ' ', text)
        text = re.sub(r'[ \t]+', ' ', text)

        text = re.sub(r'([,\.?])+', r'\1', text)

        return text.strip()
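
    # e.g. _sanitize_text('Hello,,  World!\n\nBye.') -> 'hello, world! bye.'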

    def _deserialize_tokens(self, tokens, num_codebooks):
        # De-interleave the flat stream into one row per codebook and undo the
        # per-codebook id shift of i * 2048
        cb = [tokens[i::num_codebooks] for i in range(num_codebooks)]
        min_len = min(c.shape[0] for c in cb)
        acoustic_tokens = torch.vstack([c[:min_len] - 2048 * i for i, c in enumerate(cb)])

        return acoustic_tokens
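
    # The flat stream consumed above is laid out frame by frame, cycling
    # through the codebooks, with codebook i's ids shifted by i * 2048:
    #   [c0_t0, c1_t0 + 2048, ..., c7_t0 + 7 * 2048, c0_t1, ...]
    # which is why striding by num_codebooks recovers the
    # (num_codebooks, num_frames) grid that Mimi's decoder expects.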

    # TODO: Use this to support batching
    def _prepare_mimi_batch(self, tokens, attention_masks):
        max_len = max(token.size(1) for token in tokens)

        padded_tokens = []
        padded_masks = []

        for token, mask in zip(tokens, attention_masks):
            # Right-pad every item along the time dimension up to max_len
            pad_len = max_len - token.size(1)

            padded_token = F.pad(token, (0, pad_len, 0, 0), value=0)
            padded_mask = F.pad(mask, (0, pad_len, 0, 0), value=0)

            padded_tokens.append(padded_token)
            padded_masks.append(padded_mask)

        stacked_tokens = torch.stack(padded_tokens, dim=0)
        stacked_masks = torch.stack(padded_masks, dim=0)

        return stacked_tokens, stacked_masks

    def preprocess(self, inputs, speaker):
        input_text = self._sanitize_text(inputs)
        input_tokens = self.tokenizer.encode(input_text)
        task_tokens = self._prepare_tts_tokens(input_tokens, speaker)
        task_tokens = torch.tensor(task_tokens).unsqueeze(0)

        return {'input_ids': task_tokens, 'attention_mask': torch.ones_like(task_tokens)}

    def _forward(self, model_inputs, **forward_args):
        # Restrict each decoding step to the id range of its codebook
        logits_processor = [
            AlternatingCodebooksLogitsProcessor(
                input_start_len=model_inputs['input_ids'].shape[-1],
                codebook_size=2048,
                num_codebooks=self.num_codebooks,
                offset=self.audio_offset,
                stop_token=self.stop_token
            )
        ]

        outputs = self.model.generate(
            model_inputs['input_ids'],
            attention_mask=model_inputs['attention_mask'],
            logits_processor=logits_processor
        )

        audio_tokens, attention_mask = [], []

        for idx, inputs in enumerate(model_inputs['input_ids']):
            # Drop the prompt; keep only the generated portion
            truncated = outputs[idx, inputs.shape[-1]:]
            end = torch.where(truncated == self.stop_token[0])[0]

            # Truncate at the first stop token, if one was generated
            if end.shape[-1] > 0:
                end = end[0]
            else:
                end = truncated.shape[-1]

            truncated = truncated[:end]
            truncated = truncated - self.audio_offset
            truncated = self._deserialize_tokens(truncated, self.num_codebooks)
            audio_tokens.append(truncated)
            attention_mask.append(torch.ones_like(truncated))

        audio_tokens = torch.vstack(audio_tokens).unsqueeze(0)
        attention_mask = torch.vstack(attention_mask).unsqueeze(0)  # kept for future batching support

        audio = self.audio_tokenizer.decode(audio_tokens).audio_values

        return {
            'audio_tokens': audio_tokens, # (B, num_codebooks, num_samples)
            'audio': audio # (B, 1, num_audio_samples)
        }

    def postprocess(self, model_outputs):
        return model_outputs
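
# Example usage (illustrative sketch; the checkpoint name is a placeholder for
# an Indri TTS checkpoint that ships this pipeline, and Mimi's 24 kHz output
# rate is assumed from the kyutai/mimi config):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   import soundfile as sf
#
#   model = AutoModelForCausalLM.from_pretrained('<indri-tts-checkpoint>')
#   tokenizer = AutoTokenizer.from_pretrained('<indri-tts-checkpoint>')
#   pipe = IndriTTSPipeline(model=model, tokenizer=tokenizer, device=0)
#   out = pipe('Hello from Indri!', speaker='[spkr_unk]')
#   sf.write('out.wav', out['audio'][0, 0].cpu().numpy(), 24000)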