Pijush2023 committed
Commit 042fb34 · verified · 1 Parent(s): 2288ea6

Delete CHATTS/model/gpt.py

Files changed (1)
  1. CHATTS/model/gpt.py +0 -265
CHATTS/model/gpt.py DELETED
@@ -1,265 +0,0 @@
- import os
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- import logging
- from tqdm import tqdm
- from einops import rearrange
- from transformers.cache_utils import Cache
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.nn.utils.parametrize as P
- from torch.nn.utils.parametrizations import weight_norm
- from transformers import LlamaModel, LlamaConfig
-
-
- class LlamaMLP(nn.Module):
-     def __init__(self, hidden_size, intermediate_size):
-         super().__init__()
-         self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size
-         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-         self.act_fn = F.silu
-
-     def forward(self, x):
-         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-         return down_proj
-
-
- class GPT_warpper(nn.Module):
-     def __init__(
-         self,
-         gpt_config,
-         num_audio_tokens,
-         num_text_tokens,
-         num_vq=4,
-         **kwargs,
-     ):
-         super().__init__()
-
-         self.logger = logging.getLogger(__name__)
-         self.gpt = self.build_model(gpt_config)
-         self.model_dim = self.gpt.config.hidden_size
-
-         self.num_vq = num_vq
-         self.emb_code = nn.ModuleList([nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)])
-         self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
-         self.head_text = weight_norm(nn.Linear(self.model_dim, num_text_tokens, bias=False), name='weight')
-         self.head_code = nn.ModuleList([weight_norm(nn.Linear(self.model_dim, num_audio_tokens, bias=False), name='weight') for i in range(self.num_vq)])
-
-     def build_model(self, config):
-
-         configuration = LlamaConfig(**config)
-         model = LlamaModel(configuration)
-         del model.embed_tokens
-
-         return model
-
-     def get_emb(self, input_ids, text_mask, **kwargs):
-
-         emb_text = self.emb_text(input_ids[text_mask][:, 0])
-
-         emb_code = [self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)]
-         emb_code = torch.stack(emb_code, 2).sum(2)
-
-         emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype)
-         emb[text_mask] = emb_text
-         emb[~text_mask] = emb_code.to(emb.dtype)
-
-         return emb
-
-     def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
-     ):
-         # With static cache, the `past_key_values` is None
-         # TODO joao: standardize interface for the different Cache classes and remove of this if
-         has_static_cache = False
-         if past_key_values is None:
-             past_key_values = getattr(self.gpt.layers[0].self_attn, "past_key_value", None)
-             has_static_cache = past_key_values is not None
-
-         past_length = 0
-         if past_key_values is not None:
-             if isinstance(past_key_values, Cache):
-                 past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-                 max_cache_length = (
-                     torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                     if past_key_values.get_max_length() is not None
-                     else None
-                 )
-                 cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-             else:
-                 cache_length = past_length = past_key_values[0][0].shape[2]
-                 max_cache_length = None
-
-         # Keep only the unprocessed tokens:
-         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-         #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-         #     input)
-         if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-             input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-         #     input_ids based on the past_length.
-         elif past_length < input_ids.shape[1]:
-             input_ids = input_ids[:, past_length:]
-         # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-         # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-         if (
-             max_cache_length is not None
-             and attention_mask is not None
-             and cache_length + input_ids.shape[1] > max_cache_length
-         ):
-             attention_mask = attention_mask[:, -max_cache_length:]
-
-         position_ids = kwargs.get("position_ids", None)
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -input_ids.shape[1] :]
-
-         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-             # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
-             # TODO: use `next_tokens` directly instead.
-             model_inputs = {"input_ids": input_ids.contiguous()}
-
-         input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-         if cache_position is None:
-             cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-         else:
-             cache_position = cache_position[-input_length:]
-
-         if has_static_cache:
-             past_key_values = None
-
-         model_inputs.update(
-             {
-                 "position_ids": position_ids,
-                 "cache_position": cache_position,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "attention_mask": attention_mask,
-             }
-         )
-         return model_inputs
-
-     def generate(
-         self,
-         emb,
-         inputs_ids,
-         temperature,
-         eos_token,
-         attention_mask = None,
-         max_new_token = 2048,
-         min_new_token = 0,
-         LogitsWarpers = [],
-         LogitsProcessors = [],
-         infer_text=False,
-         return_attn=False,
-         return_hidden=False,
-     ):
-
-         with torch.no_grad():
-
-             attentions = []
-             hiddens = []
-
-             start_idx, end_idx = inputs_ids.shape[1], torch.zeros(inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long)
-             finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
-
-             temperature = temperature[None].expand(inputs_ids.shape[0], -1)
-             temperature = rearrange(temperature, "b n -> (b n) 1")
-
-             attention_mask_cache = torch.ones((inputs_ids.shape[0], inputs_ids.shape[1]+max_new_token,), dtype=torch.bool, device=inputs_ids.device)
-             if attention_mask is not None:
-                 attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
-
-             for i in tqdm(range(max_new_token)):
-
-                 model_input = self.prepare_inputs_for_generation(inputs_ids,
-                     outputs.past_key_values if i!=0 else None,
-                     attention_mask_cache[:, :inputs_ids.shape[1]], use_cache=True)
-
-                 if i == 0:
-                     model_input['inputs_embeds'] = emb
-                 else:
-                     if infer_text:
-                         model_input['inputs_embeds'] = self.emb_text(model_input['input_ids'][:,:,0])
-                     else:
-                         code_emb = [self.emb_code[i](model_input['input_ids'][:,:,i]) for i in range(self.num_vq)]
-                         model_input['inputs_embeds'] = torch.stack(code_emb, 3).sum(3)
-
-                 model_input['input_ids'] = None
-                 outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
-                 attentions.append(outputs.attentions)
-                 hidden_states = outputs[0] # 🐻
-                 if return_hidden:
-                     hiddens.append(hidden_states[:, -1])
-
-                 with P.cached():
-                     if infer_text:
-                         logits = self.head_text(hidden_states)
-                     else:
-                         logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
-
-                 logits = logits[:, -1].float()
-
-                 if not infer_text:
-                     logits = rearrange(logits, "b c n -> (b n) c")
-                     logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
-                 else:
-                     logits_token = inputs_ids[:, start_idx:, 0]
-
-                 logits = logits / temperature
-
-                 for logitsProcessors in LogitsProcessors:
-                     logits = logitsProcessors(logits_token, logits)
-
-                 for logitsWarpers in LogitsWarpers:
-                     logits = logitsWarpers(logits_token, logits)
-
-                 if i < min_new_token:
-                     logits[:, eos_token] = -torch.inf
-
-                 scores = F.softmax(logits, dim=-1)
-
-                 idx_next = torch.multinomial(scores, num_samples=1)
-
-                 if not infer_text:
-                     idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
-                     finish = finish | (idx_next == eos_token).any(1)
-                     inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
-                 else:
-                     finish = finish | (idx_next == eos_token).any(1)
-                     inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq)], 1)
-
-                 end_idx = end_idx + (~finish).int()
-
-                 if finish.all():
-                     break
-
-             inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
-             inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
-
-             if return_hidden:
-                 hiddens = torch.stack(hiddens, 1)
-                 hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
-
-             if not finish.all():
-                 self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
-
-             return {
-                 'ids': inputs_ids,
-                 'attentions': attentions,
-                 'hiddens':hiddens,
-             }
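
For reference, a minimal sketch of how the deleted GPT_warpper class is constructed and how its get_emb helper is called. Every size below (hidden_size, token counts, sequence length) is a placeholder chosen for illustration, not a value taken from this repository; the real configuration lives elsewhere in the project.

import torch

# Placeholder LlamaConfig kwargs; assumes the GPT_warpper class shown in the diff above.
cfg = dict(hidden_size=256, intermediate_size=1024, num_hidden_layers=2,
           num_attention_heads=4, max_position_embeddings=512)
gpt = GPT_warpper(cfg, num_audio_tokens=1024, num_text_tokens=512, num_vq=4)

# input_ids has shape (batch, seq, num_vq); text_mask marks which positions hold text tokens.
input_ids = torch.zeros(1, 8, 4, dtype=torch.long)
text_mask = torch.zeros(1, 8, dtype=torch.bool)
text_mask[:, :4] = True                     # first half text, second half audio codes
emb = gpt.get_emb(input_ids, text_mask)     # -> (1, 8, 256), the embeddings fed to generate()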