jatingocodeo committed (verified)
Commit 6c13380 · 1 Parent(s): e061e9b

Update app.py

Files changed (1)
  1. app.py +15 -3
app.py CHANGED
@@ -150,7 +150,9 @@ class SmolLM2ForCausalLM(PreTrainedModel):
         if config.tie_word_embeddings:
             self.lm_head.weight = self.embed_tokens.weight
 
-    def forward(self, input_ids, attention_mask=None, labels=None):
+    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=None, **kwargs):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         hidden_states = self.embed_tokens(input_ids)
 
         # Create causal attention mask if none provided
@@ -171,8 +173,18 @@ class SmolLM2ForCausalLM(PreTrainedModel):
         loss = None
         if labels is not None:
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
-
-        return logits if loss is None else (loss, logits)
+
+        if return_dict:
+            from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+            return CausalLMOutputWithCrossAttentions(
+                loss=loss,
+                logits=logits,
+                past_key_values=None,
+                hidden_states=None,
+                attentions=None,
+                cross_attentions=None,
+            )
+        return (loss, logits) if loss is not None else logits
 
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {
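
For context, a minimal caller-side sketch of the updated forward signature. It assumes an already-instantiated SmolLM2ForCausalLM bound to `model` and prepared `input_ids` / `labels` tensors; those names are placeholders for illustration, not part of this commit.

    # Hypothetical sketch; `model`, `input_ids`, `labels` are placeholder names.
    outputs = model(input_ids, labels=labels, return_dict=True)
    loss, logits = outputs.loss, outputs.logits   # fields on CausalLMOutputWithCrossAttentions

    # When return_dict resolves to False (passed explicitly, or via config.use_return_dict),
    # the pre-existing behaviour is kept:
    loss, logits = model(input_ids, labels=labels, return_dict=False)   # tuple (loss, logits)
    logits = model(input_ids, return_dict=False)                        # bare logits tensor

Returning a ModelOutput when return_dict is set is what lets callers that read attribute-style fields (e.g. transformers generation utilities accessing `.logits`) consume the model without unpacking tuples, while existing tuple-based callers keep working.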