Nguyen Tien committed on
Commit 0881b5a · 1 Parent(s): 233d8e2

Update modeling_mpt.py

Files changed (1)
  1. modeling_mpt.py +121 -99
modeling_mpt.py CHANGED
@@ -1,69 +1,57 @@
 """A simple, flexible implementation of a GPT model.
-
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 import math
 import warnings
-from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from .attention import attn_bias_shape, build_attn_bias
 from .blocks import MPTBlock
-from .custom_embedding import SharedEmbedding
-from .fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
-from .ffn import FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
-from .ffn import MPTMLP as MPTMLP
-from .ffn import build_ffn as build_ffn
 from .norm import NORM_CLASS_REGISTRY
 from .configuration_mpt import MPTConfig
 from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
 from .meta_init_context import init_empty_weights
-from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
-try:
-    from .flash_attn_triton import flash_attn_func as flash_attn_func
-except:
-    pass
-import logging
-log = logging.getLogger(__name__)
+from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
+Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

 class MPTPreTrainedModel(PreTrainedModel):
     config_class = MPTConfig
     base_model_prefix = 'model'
-    _no_split_modules = ['MPTBlock']
+    _no_split_modules = ["MPTBlock"]
+    supports_gradient_checkpointing = True
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, MPTModel):
+            module.gradient_checkpointing = value

 class MPTModel(MPTPreTrainedModel):

     def __init__(self, config: MPTConfig):
         config._validate_config()
         super().__init__(config)
+        self.gradient_checkpointing = False
         self.attn_impl = config.attn_config['attn_impl']
         self.prefix_lm = config.attn_config['prefix_lm']
         self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
         self.alibi = config.attn_config['alibi']
         self.alibi_bias_max = config.attn_config['alibi_bias_max']
-        self.learned_pos_emb = config.learned_pos_emb
-        if config.init_device == 'mixed':
-            if dist.get_local_rank() == 0:
-                config.init_device = 'cpu'
-            else:
-                config.init_device = 'meta'
         if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
             norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
             raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
         norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
         self.embedding_fraction = config.embedding_fraction
-        self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
-        if self.learned_pos_emb:
-            self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
+        self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
+        if not self.alibi:
+            self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
         self.emb_drop = nn.Dropout(config.emb_pdrop)
         self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
         self.norm_f = norm_class(config.d_model, device=config.init_device)
         if config.init_device != 'meta':
-            log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
             self.apply(self.param_init_fn)
         self.is_causal = not self.prefix_lm
         self._attn_bias_initialized = False
@@ -72,22 +60,25 @@ class MPTModel(MPTPreTrainedModel):
         if config.no_bias:
             for module in self.modules():
                 if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
-                    log.info(f'Removing bias ({module.bias}) from {module}.')
+                    if config.verbose:
+                        warnings.warn(f'Removing bias ({module.bias}) from {module}.')
                     module.register_parameter('bias', None)
-                if hasattr(module, 'use_bias'):
-                    log.info(f'Setting use_bias=False for {module}.')
-                    module.use_bias = False
-        log.debug(self)
-        log.debug(f"Using {self.config.init_config['name']} initialization.")
+        if config.verbose and config.verbose > 2:
+            print(self)
+        if 'verbose' not in self.config.init_config:
+            self.config.init_config['verbose'] = self.config.verbose
+        if self.config.init_config['verbose'] > 1:
+            init_fn_name = self.config.init_config['name']
+            warnings.warn(f'Using {init_fn_name} initialization.')

-    def get_input_embeddings(self) -> nn.Embedding:
+    def get_input_embeddings(self):
         return self.wte

-    def set_input_embeddings(self, value: nn.Embedding) -> None:
+    def set_input_embeddings(self, value):
         self.wte = value

     @torch.no_grad()
-    def _attn_bias(self, device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
+    def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
         if not self._attn_bias_initialized:
             if self.attn_bias_shape:
                 self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
@@ -110,15 +101,14 @@ class MPTModel(MPTPreTrainedModel):
             if attn_bias is None:
                 attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
             else:
-                _s_k = max(0, attn_bias.size(-1) - s_k)
-                attn_bias = attn_bias[:, :, :, _s_k:]
+                attn_bias = attn_bias[:, :, :, -s_k:]
             if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
                 raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
             min_val = torch.finfo(attn_bias.dtype).min
             attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
         return (attn_bias, None)

-    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor:
+    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
         (s_k, s_q) = attn_bias.shape[-2:]
         if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
             raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
@@ -133,7 +123,7 @@ class MPTModel(MPTPreTrainedModel):
         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
         return attn_bias

-    def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor) -> torch.Tensor:
+    def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
         seq_len = sequence_id.shape[-1]
         if seq_len > self.config.max_seq_len:
             raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
@@ -143,86 +133,122 @@ class MPTModel(MPTPreTrainedModel):
         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
         return attn_bias

-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                use_cache = False
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
         if attention_mask is not None:
             attention_mask = attention_mask.bool()
+        else:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+
+        if inputs_embeds is None:
+            tok_emb = self.wte(input_ids)
+        else:
+            tok_emb = inputs_embeds
+
         if prefix_mask is not None:
             prefix_mask = prefix_mask.bool()
         if not return_dict:
             raise NotImplementedError('return_dict False is not implemented yet for MPT')
         if output_attentions:
-            if self.attn_impl != 'torch':
-                raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
-        if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
-            raise NotImplementedError('MPT does not support training with left padding.')
+            raise NotImplementedError('output_attentions is not implemented yet for MPT')
+        #if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
+        #    raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
-        if inputs_embeds is not None:
-            raise NotImplementedError('inputs_embeds is not implemented for MPT.')
         if self.training:
             if self.attn_uses_sequence_id and sequence_id is None:
                 raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
             elif self.attn_uses_sequence_id is False and sequence_id is not None:
                 warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
-        S = input_ids.size(1)
+        S = seq_length
         assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
-        tok_emb = self.wte(input_ids)
-        if self.learned_pos_emb:
+        if self.alibi:
+            x = tok_emb
+        else:
             past_position = 0
             if past_key_values is not None:
                 if len(past_key_values) != self.config.n_layers:
                     raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
                 past_position = past_key_values[0][0].size(1)
-                if self.attn_impl == 'torch':
-                    past_position = past_key_values[0][0].size(3)
             if S + past_position > self.config.max_seq_len:
-                raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
+                raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
             pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
-            if attention_mask is not None:
+            if attention_mask is not None and not self.training:
                 pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
-        else:
-            x = tok_emb
         if self.embedding_fraction == 1:
             x = self.emb_drop(x)
         else:
             x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
             assert isinstance(self.emb_drop, nn.Module)
             x = self.emb_drop(x_shrunk)
-        (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
-        presents = () if use_cache else None
+        (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
         if use_cache and past_key_values is None:
             past_key_values = [() for _ in range(self.config.n_layers)]
+
         all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
         for (b_idx, block) in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
-            (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
-            if presents is not None:
-                presents += (present,)
-            if output_attentions:
-                assert all_self_attns is not None
-                all_self_attns = all_self_attns + (attn_weights,)
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs)
+
+                    return custom_forward
+
+                (x, past_key_value) = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    x,
+                    past_key_value,
+                    attn_bias,
+                    attention_mask,
+                    self.is_causal,
+                )
+            else:
+                (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
+
+            if past_key_values is not None:
+                past_key_values[b_idx] = past_key_value
         x = self.norm_f(x)
-        if output_hidden_states:
-            assert all_hidden_states is not None
-            all_hidden_states = all_hidden_states + (x,)
-        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attns)
+        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)

-    def param_init_fn(self, module: nn.Module) -> None:
+    def param_init_fn(self, module):
         init_fn_name = self.config.init_config['name']
         MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)

-    def fsdp_wrap_fn(self, module: nn.Module) -> bool:
+    def fsdp_wrap_fn(self, module):
         return isinstance(module, MPTBlock)

-    def activation_checkpointing_fn(self, module: nn.Module) -> bool:
+    def activation_checkpointing_fn(self, module):
         return isinstance(module, MPTBlock)

 class MPTForCausalLM(MPTPreTrainedModel):
@@ -231,13 +257,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         super().__init__(config)
         if not config.tie_word_embeddings:
             raise ValueError('MPTForCausalLM only supports tied word embeddings')
-        log.info(f'Instantiating an MPTForCausalLM model from {__file__}')
-        self.transformer: MPTModel = MPTModel(config)
-        for child in self.transformer.children():
-            if isinstance(child, torch.nn.ModuleList):
-                continue
-            if isinstance(child, torch.nn.Module):
-                child._fsdp_wrap = True
+        self.transformer = MPTModel(config)
         self.logit_scale = None
         if config.logit_scale is not None:
             logit_scale = config.logit_scale
@@ -248,53 +268,56 @@ class MPTForCausalLM(MPTPreTrainedModel):
             raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
         self.logit_scale = logit_scale

-    def get_input_embeddings(self) -> nn.Embedding:
+    def get_input_embeddings(self):
         return self.transformer.wte

-    def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
+    def set_input_embeddings(self, value):
         self.transformer.wte = value

-    def get_output_embeddings(self) -> nn.Embedding:
+    def get_output_embeddings(self):
         return self.transformer.wte

-    def set_output_embeddings(self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None:
+    def set_output_embeddings(self, new_embeddings):
         self.transformer.wte = new_embeddings

-    def set_decoder(self, decoder: MPTModel) -> None:
+    def set_decoder(self, decoder):
         self.transformer = decoder

-    def get_decoder(self) -> MPTModel:
+    def get_decoder(self):
         return self.transformer

-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None) -> CausalLMOutputWithPast:
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        if inputs_embeds is not None:
-            raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
-        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
-        logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
+        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
+
+        last_hidden_state = outputs.last_hidden_state
+        if self.model_parallel:
+            last_hidden_state = last_hidden_state.to(self.transformer.wte.weight.device)
+        logits = F.linear(last_hidden_state, self.transformer.wte.weight)
+
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
             logits *= self.logit_scale
         loss = None
         if labels is not None:
-            _labels = torch.roll(labels, shifts=-1)
-            _labels[:, -1] = -100
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
-        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+            labels = torch.roll(labels, shifts=-1)
+            labels[:, -1] = -100
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)

-    def param_init_fn(self, module: nn.Module) -> None:
+    def param_init_fn(self, module):
         init_fn_name = self.config.init_config['name']
         MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)

-    def fsdp_wrap_fn(self, module: nn.Module) -> bool:
+    def fsdp_wrap_fn(self, module):
         return isinstance(module, MPTBlock)

-    def activation_checkpointing_fn(self, module: nn.Module) -> bool:
+    def activation_checkpointing_fn(self, module):
         return isinstance(module, MPTBlock)

-    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]=None, inputs_embeds: Optional[torch.Tensor]=None, **kwargs: Any) -> Dict[str, Any]:
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         if inputs_embeds is not None:
             raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
         attention_mask = kwargs['attention_mask'].bool()
@@ -315,9 +338,8 @@ class MPTForCausalLM(MPTPreTrainedModel):
         return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}

     @staticmethod
-    def _reorder_cache(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], beam_idx: torch.LongTensor) -> List[Tuple[torch.Tensor, ...]]:
+    def _reorder_cache(past_key_values, beam_idx):
         """Used by HuggingFace generate when using beam search with kv-caching.
-
         See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
         for an example in transformers.
         """