amaye15 commited on
Commit
c386892
·
verified ·
1 Parent(s): d2f7259

Update modeling_aimv2.py

Browse files
Files changed (1) hide show
  1. modeling_aimv2.py +8 -4
modeling_aimv2.py CHANGED
@@ -222,7 +222,7 @@ class AIMv2Model(AIMv2PretrainedModel):
222
  hidden_states=hidden_states,
223
  )
224
 
225
- '''
226
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
227
  def __init__(self, config: AIMv2Config):
228
  super().__init__(config)
@@ -271,6 +271,7 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
271
 
272
  loss = None
273
  if labels is not None:
 
274
  # move labels to correct device to enable model parallelism
275
  labels = labels.to(logits.device)
276
  if self.config.problem_type is None:
@@ -295,7 +296,10 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
295
  elif self.config.problem_type == "multi_label_classification":
296
  loss_fct = BCEWithLogitsLoss()
297
  loss = loss_fct(logits, labels)
298
-
 
 
 
299
  if not return_dict:
300
  output = (logits,) + outputs[1:]
301
  return ((loss,) + output) if loss is not None else output
@@ -306,9 +310,9 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
306
  hidden_states=outputs.hidden_states,
307
  # attentions=outputs.attentions,
308
  )
309
- '''
310
 
311
 
 
312
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
313
  def __init__(self, config: AIMv2Config):
314
  super().__init__(config)
@@ -381,4 +385,4 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
381
  logits=logits,
382
  hidden_states=outputs.hidden_states,
383
  )
384
-
 
222
  hidden_states=hidden_states,
223
  )
224
 
225
+
226
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
227
  def __init__(self, config: AIMv2Config):
228
  super().__init__(config)
 
271
 
272
  loss = None
273
  if labels is not None:
274
+ print("LABELS: ", labels)
275
  # move labels to correct device to enable model parallelism
276
  labels = labels.to(logits.device)
277
  if self.config.problem_type is None:
 
296
  elif self.config.problem_type == "multi_label_classification":
297
  loss_fct = BCEWithLogitsLoss()
298
  loss = loss_fct(logits, labels)
299
+
300
+ print("PROBLEM", self.config.problem_type)
301
+ print("LOSS: ", loss)
302
+
303
  if not return_dict:
304
  output = (logits,) + outputs[1:]
305
  return ((loss,) + output) if loss is not None else output
 
310
  hidden_states=outputs.hidden_states,
311
  # attentions=outputs.attentions,
312
  )
 
313
 
314
 
315
+ '''
316
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
317
  def __init__(self, config: AIMv2Config):
318
  super().__init__(config)
 
385
  logits=logits,
386
  hidden_states=outputs.hidden_states,
387
  )
388
+ '''