Update modeling_aimv2.py
Browse files- modeling_aimv2.py +8 -4
modeling_aimv2.py
CHANGED
@@ -222,7 +222,7 @@ class AIMv2Model(AIMv2PretrainedModel):
|
|
222 |
hidden_states=hidden_states,
|
223 |
)
|
224 |
|
225 |
-
|
226 |
class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
227 |
def __init__(self, config: AIMv2Config):
|
228 |
super().__init__(config)
|
@@ -271,6 +271,7 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
|
271 |
|
272 |
loss = None
|
273 |
if labels is not None:
|
|
|
274 |
# move labels to correct device to enable model parallelism
|
275 |
labels = labels.to(logits.device)
|
276 |
if self.config.problem_type is None:
|
@@ -295,7 +296,10 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
|
295 |
elif self.config.problem_type == "multi_label_classification":
|
296 |
loss_fct = BCEWithLogitsLoss()
|
297 |
loss = loss_fct(logits, labels)
|
298 |
-
|
|
|
|
|
|
|
299 |
if not return_dict:
|
300 |
output = (logits,) + outputs[1:]
|
301 |
return ((loss,) + output) if loss is not None else output
|
@@ -306,9 +310,9 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
|
306 |
hidden_states=outputs.hidden_states,
|
307 |
# attentions=outputs.attentions,
|
308 |
)
|
309 |
-
'''
|
310 |
|
311 |
|
|
|
312 |
class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
313 |
def __init__(self, config: AIMv2Config):
|
314 |
super().__init__(config)
|
@@ -381,4 +385,4 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
|
381 |
logits=logits,
|
382 |
hidden_states=outputs.hidden_states,
|
383 |
)
|
384 |
-
|
|
|
222 |
hidden_states=hidden_states,
|
223 |
)
|
224 |
|
225 |
+
|
226 |
class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
227 |
def __init__(self, config: AIMv2Config):
|
228 |
super().__init__(config)
|
|
|
271 |
|
272 |
loss = None
|
273 |
if labels is not None:
|
274 |
+
print("LABELS: ", labels)
|
275 |
# move labels to correct device to enable model parallelism
|
276 |
labels = labels.to(logits.device)
|
277 |
if self.config.problem_type is None:
|
|
|
296 |
elif self.config.problem_type == "multi_label_classification":
|
297 |
loss_fct = BCEWithLogitsLoss()
|
298 |
loss = loss_fct(logits, labels)
|
299 |
+
|
300 |
+
print("PROBLEM", self.config.problem_type)
|
301 |
+
print("LOSS: ", loss)
|
302 |
+
|
303 |
if not return_dict:
|
304 |
output = (logits,) + outputs[1:]
|
305 |
return ((loss,) + output) if loss is not None else output
|
|
|
310 |
hidden_states=outputs.hidden_states,
|
311 |
# attentions=outputs.attentions,
|
312 |
)
|
|
|
313 |
|
314 |
|
315 |
+
'''
|
316 |
class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
317 |
def __init__(self, config: AIMv2Config):
|
318 |
super().__init__(config)
|
|
|
385 |
logits=logits,
|
386 |
hidden_states=outputs.hidden_states,
|
387 |
)
|
388 |
+
'''
|