Update modeling_custom.py
Browse files- modeling_custom.py +1 -1
modeling_custom.py
CHANGED
@@ -147,7 +147,7 @@ class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
|
|
147 |
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
148 |
"""
|
149 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
150 |
-
if input_ids.shape[0] == 1 and len(input_ids.shape) == 2 and input_ids[0,0] == input_ids[0,
|
151 |
input_ids = input_ids[:, 1:]
|
152 |
if attention_mask is not None:
|
153 |
attention_mask = attention_mask[:, 1:]
|
|
|
147 |
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
148 |
"""
|
149 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
150 |
+
if input_ids.shape[0] == 1 and len(input_ids.shape) == 2 and input_ids[0,0] == input_ids[0,1] == 2:
|
151 |
input_ids = input_ids[:, 1:]
|
152 |
if attention_mask is not None:
|
153 |
attention_mask = attention_mask[:, 1:]
|