Update modeling_internimage.py
#2
by
parakh01
- opened
- modeling_internimage.py +55 -15
modeling_internimage.py
CHANGED
@@ -800,23 +800,31 @@ class InternImage(nn.Module):
|
|
800 |
'pooler_output': x if self.num_classes > 0 else None
|
801 |
}
|
802 |
|
803 |
-
def forward(self,
|
|
|
|
|
|
|
|
|
804 |
if self.use_clip_projector: # for InternImage-H/G
|
805 |
-
outputs = self.forward_clip_projector(
|
806 |
else: # for InternImage-T/S/B/L/XL
|
807 |
-
outputs = self.forward_features(
|
808 |
|
809 |
-
hidden_states = outputs['hidden_states']
|
810 |
-
pooler_output = outputs['pooler_output']
|
|
|
811 |
|
812 |
if self.num_classes > 0:
|
813 |
-
logits = self.head(pooler_output)
|
814 |
else:
|
815 |
logits = None
|
816 |
|
|
|
|
|
|
|
817 |
return BackboneOutput(
|
818 |
hidden_states=hidden_states,
|
819 |
-
last_hidden_state=
|
820 |
pooler_output=pooler_output,
|
821 |
logits=logits
|
822 |
)
|
@@ -853,8 +861,17 @@ class InternImageModel(PreTrainedModel):
|
|
853 |
remove_center=config.remove_center, # for InternImage-H/G
|
854 |
)
|
855 |
|
856 |
-
def forward(self,
|
857 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
858 |
|
859 |
|
860 |
class InternImageModelForImageClassification(PreTrainedModel):
|
@@ -862,6 +879,7 @@ class InternImageModelForImageClassification(PreTrainedModel):
|
|
862 |
|
863 |
def __init__(self, config):
|
864 |
super().__init__(config)
|
|
|
865 |
self.model = InternImage(
|
866 |
core_op=config.core_op,
|
867 |
channels=config.channels,
|
@@ -888,12 +906,34 @@ class InternImageModelForImageClassification(PreTrainedModel):
|
|
888 |
remove_center=config.remove_center, # for InternImage-H/G
|
889 |
)
|
890 |
|
891 |
-
def forward(self,
|
892 |
-
|
893 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
894 |
if labels is not None:
|
895 |
-
logits = outputs[
|
896 |
loss = F.cross_entropy(logits, labels)
|
897 |
-
outputs['loss'] = loss
|
898 |
|
899 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
800 |
'pooler_output': x if self.num_classes > 0 else None
|
801 |
}
|
802 |
|
803 |
+
def forward(self,
            pixel_values,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None):
    """Run the backbone and, if ``num_classes > 0``, the classification head.

    Args:
        pixel_values: Input batch, passed unchanged to
            ``forward_clip_projector`` or ``forward_features``.
        output_attentions: Accepted so callers can use a uniform keyword
            set; no attention output is produced here and the flag is
            ignored.
        output_hidden_states: When truthy, the full ``hidden_states``
            sequence is included in the result. ``False`` and ``None``
            both leave it out.
        return_dict: When truthy, return a ``BackboneOutput``. Any falsy
            value (including the default ``None``) returns a tuple.

    Returns:
        A ``BackboneOutput``, or a tuple holding the non-``None`` values
        among ``(logits, hidden_states, pooler_output, last_hidden_state)``
        in that order.
    """
    if self.use_clip_projector:  # for InternImage-H/G
        outputs = self.forward_clip_projector(pixel_values)
    else:  # for InternImage-T/S/B/L/XL
        outputs = self.forward_features(pixel_values)

    all_hidden_states = outputs['hidden_states']
    # The pooled features and the final stage output are the primary results,
    # so they are always returned; only the full per-stage list is optional.
    pooler_output = outputs['pooler_output']
    last_hidden_state = all_hidden_states[-1]
    # Truthiness, not ``is not None``: an explicit ``False`` must disable it.
    hidden_states = all_hidden_states if output_hidden_states else None

    if self.num_classes > 0:
        logits = self.head(pooler_output)
    else:
        logits = None

    if not return_dict:
        candidates = (logits, hidden_states, pooler_output, last_hidden_state)
        return tuple(v for v in candidates if v is not None)

    return BackboneOutput(
        hidden_states=hidden_states,
        last_hidden_state=last_hidden_state,
        pooler_output=pooler_output,
        logits=logits
    )
|
|
|
861 |
remove_center=config.remove_center, # for InternImage-H/G
|
862 |
)
|
863 |
|
864 |
+
def forward(self,
            pixel_values,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None):
    """Return the backbone features for ``pixel_values``.

    Args:
        pixel_values: Input batch handed to ``self.model.forward_features``.
        output_attentions: Accepted for a uniform calling convention; unused.
        output_hidden_states: Accepted for a uniform calling convention;
            unused.
        return_dict: Accepted for a uniform calling convention; unused.

    Returns:
        Whatever ``self.model.forward_features`` returns for the input.
    """
    # The output flags are deliberately not forwarded: forward_features is
    # invoked with only the input tensor by the backbone's own forward, so
    # passing these keywords would raise TypeError.
    # NOTE(review): forward_features' signature is not visible from this
    # block -- confirm it takes only the input.
    return self.model.forward_features(pixel_values)
|
875 |
|
876 |
|
877 |
class InternImageModelForImageClassification(PreTrainedModel):
|
|
|
879 |
|
880 |
def __init__(self, config):
|
881 |
super().__init__(config)
|
882 |
+
self.config = config
|
883 |
self.model = InternImage(
|
884 |
core_op=config.core_op,
|
885 |
channels=config.channels,
|
|
|
906 |
remove_center=config.remove_center, # for InternImage-H/G
|
907 |
)
|
908 |
|
909 |
+
def forward(self,
            pixel_values,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None):
    """Classify ``pixel_values`` and optionally compute a cross-entropy loss.

    Args:
        pixel_values: Input batch passed to the wrapped ``self.model``.
        labels: Optional targets; when given, ``F.cross_entropy`` is
            computed between the logits and these labels.
        output_attentions: Forwarded unchanged to ``self.model``.
        output_hidden_states: Forwarded unchanged to ``self.model``.
        return_dict: Whether to return a ``BackboneOutput`` instead of a
            tuple. ``None`` falls back to ``self.config.use_return_dict``.

    Returns:
        With ``return_dict`` truthy, a ``BackboneOutput`` carrying ``loss``,
        ``logits``, ``hidden_states``, ``last_hidden_state`` and
        ``pooler_output``. Otherwise the tuple returned by ``self.model``,
        prefixed with the loss when ``labels`` was supplied.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Call the module itself rather than ``.forward`` so that any hooks
    # registered on the wrapped model are executed.
    outputs = self.model(
        pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict)

    loss = None
    if labels is not None:
        # In tuple mode the wrapped model places logits first.
        logits = outputs.logits if return_dict else outputs[0]
        loss = F.cross_entropy(logits, labels)

    if not return_dict:
        # ``outputs`` is already a tuple; rebuilding it via ``outputs[0]``
        # added nothing and failed on an empty result.
        return ((loss,) + outputs) if loss is not None else outputs

    return BackboneOutput(
        loss=loss,
        logits=outputs.logits,
        hidden_states=outputs.hidden_states,
        last_hidden_state=outputs.last_hidden_state,
        pooler_output=outputs.pooler_output
    )
|