nicolinho commited on
Commit
83d5ac5
·
verified ·
1 Parent(s): 56f4005

Update modeling_custom.py

Browse files
Files changed (1) hide show
  1. modeling_custom.py +1 -1
modeling_custom.py CHANGED
@@ -147,7 +147,7 @@ class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
147
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
148
  """
149
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
150
- if input_ids.shape[0] == 1 and len(input_ids.shape) == 2 and input_ids[0,0] == input_ids[0,0] == 2:
151
  input_ids = input_ids[:, 1:]
152
  if attention_mask is not None:
153
  attention_mask = attention_mask[:, 1:]
 
147
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
148
  """
149
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
150
+ if input_ids.shape[0] == 1 and len(input_ids.shape) == 2 and input_ids[0,0] == input_ids[0,1] == 2:
151
  input_ids = input_ids[:, 1:]
152
  if attention_mask is not None:
153
  attention_mask = attention_mask[:, 1:]