Update modeling_internimage.py

#2
by parakh01 - opened
Files changed (1) hide show
  1. modeling_internimage.py +55 -15
modeling_internimage.py CHANGED
@@ -800,23 +800,31 @@ class InternImage(nn.Module):
800
  'pooler_output': x if self.num_classes > 0 else None
801
  }
802
 
803
- def forward(self, x):
 
 
 
 
804
  if self.use_clip_projector: # for InternImage-H/G
805
- outputs = self.forward_clip_projector(x)
806
  else: # for InternImage-T/S/B/L/XL
807
- outputs = self.forward_features(x)
808
 
809
- hidden_states = outputs['hidden_states']
810
- pooler_output = outputs['pooler_output']
 
811
 
812
  if self.num_classes > 0:
813
- logits = self.head(pooler_output)
814
  else:
815
  logits = None
816
 
 
 
 
817
  return BackboneOutput(
818
  hidden_states=hidden_states,
819
- last_hidden_state=hidden_states[-1],
820
  pooler_output=pooler_output,
821
  logits=logits
822
  )
@@ -853,8 +861,17 @@ class InternImageModel(PreTrainedModel):
853
  remove_center=config.remove_center, # for InternImage-H/G
854
  )
855
 
856
- def forward(self, pixel_values):
857
- return self.model.forward_features(pixel_values)
 
 
 
 
 
 
 
 
 
858
 
859
 
860
  class InternImageModelForImageClassification(PreTrainedModel):
@@ -862,6 +879,7 @@ class InternImageModelForImageClassification(PreTrainedModel):
862
 
863
  def __init__(self, config):
864
  super().__init__(config)
 
865
  self.model = InternImage(
866
  core_op=config.core_op,
867
  channels=config.channels,
@@ -888,12 +906,34 @@ class InternImageModelForImageClassification(PreTrainedModel):
888
  remove_center=config.remove_center, # for InternImage-H/G
889
  )
890
 
891
- def forward(self, pixel_values, labels=None):
892
- outputs = self.model.forward(pixel_values)
893
-
 
 
 
 
 
 
 
 
 
 
 
 
 
894
  if labels is not None:
895
- logits = outputs['logits']
896
  loss = F.cross_entropy(logits, labels)
897
- outputs['loss'] = loss
898
 
899
- return outputs
 
 
 
 
 
 
 
 
 
 
 
800
  'pooler_output': x if self.num_classes > 0 else None
801
  }
802
 
803
def forward(self,
            pixel_values,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None):
    """Run the backbone and, when a classifier head is configured, the head.

    Args:
        pixel_values: Input batch, handed unchanged to
            ``forward_clip_projector`` or ``forward_features``.
        output_attentions: Accepted for signature compatibility only; this
            method does not read it.
        output_hidden_states: When truthy, the full ``hidden_states``
            sequence produced by the backbone is included in the result.
            ``None`` and ``False`` both leave it out.
        return_dict: An explicit ``False`` yields a plain tuple; ``True`` or
            ``None`` yields a ``BackboneOutput``.

    Returns:
        A ``BackboneOutput``, or (with ``return_dict=False``) a tuple holding
        the non-``None`` values in the order ``logits``, ``hidden_states``,
        ``pooler_output``, ``last_hidden_state``.
    """
    if self.use_clip_projector:  # for InternImage-H/G
        outputs = self.forward_clip_projector(pixel_values)
    else:  # for InternImage-T/S/B/L/XL
        outputs = self.forward_features(pixel_values)

    all_hidden_states = outputs['hidden_states']

    # pooler_output and last_hidden_state are always reported: they are not
    # attention maps, and gating them on the flags left a default call with
    # nothing but (possibly absent) logits.
    pooler_output = outputs['pooler_output']
    last_hidden_state = all_hidden_states[-1]

    # Truthiness, not "is not None": an explicit False must switch this off.
    hidden_states = all_hidden_states if output_hidden_states else None

    logits = self.head(pooler_output) if self.num_classes > 0 else None

    # There is no config here to resolve None against, so only an explicit
    # False selects the tuple form; None keeps the structured output.
    if return_dict is not None and not return_dict:
        ordered = (logits, hidden_states, pooler_output, last_hidden_state)
        return tuple(value for value in ordered if value is not None)

    return BackboneOutput(
        hidden_states=hidden_states,
        last_hidden_state=last_hidden_state,
        pooler_output=pooler_output,
        logits=logits,
    )
 
861
  remove_center=config.remove_center, # for InternImage-H/G
862
  )
863
 
864
def forward(self,
            pixel_values,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None):
    """Return the wrapped backbone's ``forward_features`` result.

    Args:
        pixel_values: Input batch, handed unchanged to
            ``self.model.forward_features``.
        output_attentions: Accepted for signature compatibility; unused.
        output_hidden_states: Accepted for signature compatibility; unused.
        return_dict: Accepted for signature compatibility; unused.

    Returns:
        Whatever ``self.model.forward_features(pixel_values)`` returns.
    """
    # NOTE(review): forward_features is called with the input alone at its
    # other call sites in view, so the three flags are not forwarded --
    # passing them as keywords would raise TypeError unless its signature is
    # extended. Confirm before wiring them through.
    return self.model.forward_features(pixel_values)
875
 
876
 
877
  class InternImageModelForImageClassification(PreTrainedModel):
 
879
 
880
  def __init__(self, config):
881
  super().__init__(config)
882
+ self.config = config
883
  self.model = InternImage(
884
  core_op=config.core_op,
885
  channels=config.channels,
 
906
  remove_center=config.remove_center, # for InternImage-H/G
907
  )
908
 
909
def forward(self,
            pixel_values,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None):
    """Classify ``pixel_values`` and optionally compute a cross-entropy loss.

    Args:
        pixel_values: Input batch, handed unchanged to ``self.model``.
        labels: Optional targets; when given, the loss is
            ``F.cross_entropy(logits, labels)``.
        output_attentions: Forwarded unchanged to ``self.model``.
        output_hidden_states: Forwarded unchanged to ``self.model``.
        return_dict: ``None`` falls back to ``self.config.use_return_dict``.
            A falsy value selects tuple output.

    Returns:
        With tuple output: the wrapped model's tuple, prefixed with the loss
        when ``labels`` was given. Otherwise a ``BackboneOutput`` carrying
        ``loss``, ``logits``, ``hidden_states``, ``last_hidden_state`` and
        ``pooler_output``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Invoke the module itself rather than .forward() so that any hooks
    # registered on it are run.
    outputs = self.model(
        pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict)

    loss = None
    if labels is not None:
        # NOTE(review): in tuple mode this assumes the logits occupy slot 0,
        # i.e. that the wrapped model has a classifier head -- confirm.
        logits = outputs.logits if return_dict else outputs[0]
        loss = F.cross_entropy(logits, labels)

    if not return_dict:
        return outputs if loss is None else (loss,) + outputs

    # NOTE(review): assumes BackboneOutput declares a `loss` field; its
    # definition is not visible from this block -- confirm.
    return BackboneOutput(
        loss=loss,
        logits=outputs.logits,
        hidden_states=outputs.hidden_states,
        last_hidden_state=outputs.last_hidden_state,
        pooler_output=outputs.pooler_output,
    )