Commit 577efb5 · Hajime Yagihara committed · 1 Parent(s): 8589ada

add inputs_embeds params to model

Files changed (1)
  1. modeling_mpt.py +31 -8
modeling_mpt.py CHANGED
@@ -140,11 +140,30 @@ class MPTModel(MPTPreTrainedModel):
         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
         return attn_bias

-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    # def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
         if attention_mask is not None:
             attention_mask = attention_mask.bool()
+        else:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+
+        if inputs_embeds is None:
+            tok_emb = self.wte(input_ids)
+        else:
+            tok_emb = inputs_embeds
+
         if prefix_mask is not None:
             prefix_mask = prefix_mask.bool()
         if not return_dict:
@@ -152,8 +171,8 @@ class MPTModel(MPTPreTrainedModel):
         if output_attentions:
             if self.attn_impl != 'torch':
                 raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
-        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
-            raise NotImplementedError('MPT does not support training with left padding.')
+        # if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
+        #     raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
         if self.training:
@@ -161,9 +180,10 @@ class MPTModel(MPTPreTrainedModel):
                 raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
             elif self.attn_uses_sequence_id is False and sequence_id is not None:
                 warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
-        S = input_ids.size(1)
+        # S = input_ids.size(1)
+        S = seq_length
         assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
-        tok_emb = self.wte(input_ids)
+        # tok_emb = self.wte(input_ids)
         if self.alibi:
             x = tok_emb
         else:
@@ -177,7 +197,8 @@ class MPTModel(MPTPreTrainedModel):
             if S + past_position > self.config.max_seq_len:
                 raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
             pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
-            if attention_mask is not None:
+            # if attention_mask is not None :
+            if attention_mask is not None and not self.training:
                 pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
@@ -259,10 +280,12 @@ class MPTForCausalLM(MPTPreTrainedModel):
     def get_decoder(self):
         return self.transformer

-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    # def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
+        # outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
+        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
         logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
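
Usage sketch (not part of the commit): the snippet below shows how a caller could exercise the new inputs_embeds path, assuming an MPT checkpoint that ships this modeling_mpt.py and is loaded with trust_remote_code=True; the checkpoint name is a placeholder. Because the patched signature keeps input_ids as a required positional parameter, None is passed explicitly, and attention_mask is supplied so the new default-mask branch (which references seq_length_with_past and inputs_embeds.device) is not taken.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'mosaicml/mpt-7b'  # placeholder; any checkpoint carrying this modeling_mpt.py
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

batch = tokenizer('Hello, MPT', return_tensors='pt')

# Build the embeddings ourselves from the shared wte table, then bypass input_ids.
# With this commit, MPTModel.forward uses inputs_embeds directly as tok_emb when input_ids is None.
embeds = model.transformer.wte(batch['input_ids'])

with torch.no_grad():
    out = model(
        input_ids=None,                          # still listed first in the signature, so pass None explicitly
        inputs_embeds=embeds,                    # new parameter added by this commit
        attention_mask=batch['attention_mask'],  # supplied explicitly (see note above)
        return_dict=True,
    )
print(out.logits.shape)  # (batch, seq_len, vocab_size)

The input_ids/inputs_embeds mutual-exclusion check mirrors the pattern used by other decoder models in transformers, which is why the error messages mention decoder_input_ids and decoder_inputs_embeds.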