cmeraki commited on
Commit
cde64ac
1 Parent(s): 4e1156c

Updated pipeline

Browse files
Files changed (1) hide show
  1. tts_pipeline.py +65 -7
tts_pipeline.py CHANGED
@@ -2,7 +2,26 @@ import re
2
  import torch
3
  import numpy as np
4
  from transformers import MimiModel, GenerationConfig
5
- from transformers import Pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class IndriTTSPipeline(Pipeline):
8
  def __init__(self, *args, **kwargs):
@@ -18,6 +37,8 @@ class IndriTTSPipeline(Pipeline):
18
  self.num_codebooks = 8
19
  self.audio_offset = 50257
20
 
 
 
21
  self.model.generation_config = GenerationConfig(
22
  eos_token_id=self.stop_token,
23
  max_length=kwargs.get('max_length', 1024),
@@ -62,21 +83,55 @@ class IndriTTSPipeline(Pipeline):
62
 
63
  return acoustic_tokens
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def preprocess(self, inputs, speaker):
66
- # TODO: Check for batching
67
  input_text = self._sanitize_text(inputs)
68
  input_tokens = self.tokenizer.encode(input_text)
69
  task_tokens = self._prepare_tts_tokens(input_tokens, speaker)
70
  task_tokens = torch.tensor(task_tokens).unsqueeze(0)
71
 
72
- return {'task_tokens': task_tokens}
73
 
74
  def _forward(self, model_inputs, **forward_args):
75
 
76
- outputs = self.model.generate(model_inputs['task_tokens'])
77
- audio_tokens = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- for idx, inputs in enumerate(model_inputs['task_tokens']):
80
  truncated = outputs[idx, inputs.shape[-1]:]
81
  end = torch.where(truncated == self.stop_token[0])[-1]
82
 
@@ -89,8 +144,11 @@ class IndriTTSPipeline(Pipeline):
89
  truncated -= self.audio_offset
90
  truncated = self._deserialize_tokens(torch.tensor(truncated), self.num_codebooks)
91
  audio_tokens.append(truncated)
 
92
 
93
  audio_tokens = torch.vstack(audio_tokens).unsqueeze(0)
 
 
94
  audio = self.audio_tokenizer.decode(audio_tokens).audio_values
95
 
96
  return {
@@ -99,4 +157,4 @@ class IndriTTSPipeline(Pipeline):
99
  }
100
 
101
  def postprocess(self, model_outputs):
102
- return model_outputs
 
2
  import torch
3
  import numpy as np
4
  from transformers import MimiModel, GenerationConfig
5
+ from transformers import Pipeline, LogitsProcessor
6
+
7
+ class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
8
+ def __init__(self, input_start_len: int, codebook_size: int, num_codebooks: int, offset: int, stop_token: int):
9
+ self.input_start_len = input_start_len
10
+ self.codebook_size = codebook_size
11
+ self.num_codebooks = num_codebooks
12
+ self.offset = offset
13
+ self.stop_token = stop_token
14
+
15
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
16
+ curr_len = input_ids.shape[-1]
17
+ codebook_idx = ((curr_len - self.input_start_len) % self.num_codebooks)
18
+
19
+ scores_processed = scores.clone()
20
+ scores_processed[:, : self.offset + codebook_idx * self.codebook_size] = -float("inf")
21
+ scores_processed[:, self.offset + (codebook_idx+1) * self.codebook_size :] = -float("inf")
22
+ scores_processed[:, self.stop_token] = scores[:, self.stop_token]
23
+
24
+ return scores_processed
25
 
26
  class IndriTTSPipeline(Pipeline):
27
  def __init__(self, *args, **kwargs):
 
37
  self.num_codebooks = 8
38
  self.audio_offset = 50257
39
 
40
+ self.model.stop_token = self.stop_token
41
+
42
  self.model.generation_config = GenerationConfig(
43
  eos_token_id=self.stop_token,
44
  max_length=kwargs.get('max_length', 1024),
 
83
 
84
  return acoustic_tokens
85
 
86
+ # TODO: Use this to support batching
87
+ def _prepare_mimi_batch(self, tokens, attention_mask):
88
+ max_len = max(token.size(1) for token in tokens)
89
+
90
+ padded_tokens = []
91
+ padded_masks = []
92
+
93
+ for token, mask in zip(tokens, attention_masks):
94
+ pad_len = max_len - token.size(1)
95
+
96
+ padded_token = F.pad(token, (0, pad_len, 0, 0), value=0)
97
+ padded_mask = F.pad(mask, (0, pad_len, 0, 0), value=0)
98
+
99
+ padded_tokens.append(padded_token)
100
+ padded_masks.append(padded_mask)
101
+
102
+ stacked_tokens = torch.stack(padded_tokens, dim=0)
103
+ stacked_masks = torch.stack(padded_masks, dim=0)
104
+
105
+ return stacked_tokens, stacked_masks
106
+
107
  def preprocess(self, inputs, speaker):
 
108
  input_text = self._sanitize_text(inputs)
109
  input_tokens = self.tokenizer.encode(input_text)
110
  task_tokens = self._prepare_tts_tokens(input_tokens, speaker)
111
  task_tokens = torch.tensor(task_tokens).unsqueeze(0)
112
 
113
+ return {'input_ids': task_tokens, 'attention_mask': torch.ones_like(task_tokens)}
114
 
115
  def _forward(self, model_inputs, **forward_args):
116
 
117
+ logits_processor=[
118
+ AlternatingCodebooksLogitsProcessor(
119
+ input_start_len=model_inputs['input_ids'].shape[-1],
120
+ codebook_size=2048,
121
+ num_codebooks=self.num_codebooks,
122
+ offset=self.audio_offset,
123
+ stop_token=self.stop_token
124
+ )
125
+ ]
126
+
127
+ outputs = self.model.generate(
128
+ model_inputs['input_ids'],
129
+ logits_processor=logits_processor
130
+ )
131
+
132
+ audio_tokens, attention_mask = [], []
133
 
134
+ for idx, inputs in enumerate(model_inputs['input_ids']):
135
  truncated = outputs[idx, inputs.shape[-1]:]
136
  end = torch.where(truncated == self.stop_token[0])[-1]
137
 
 
144
  truncated -= self.audio_offset
145
  truncated = self._deserialize_tokens(torch.tensor(truncated), self.num_codebooks)
146
  audio_tokens.append(truncated)
147
+ attention_mask.append(torch.ones_like(truncated))
148
 
149
  audio_tokens = torch.vstack(audio_tokens).unsqueeze(0)
150
+ attention_mask = torch.vstack(attention_mask).unsqueeze(0)
151
+
152
  audio = self.audio_tokenizer.decode(audio_tokens).audio_values
153
 
154
  return {
 
157
  }
158
 
159
  def postprocess(self, model_outputs):
160
+ return model_outputs