Upload models
Browse files- modeling_internimage.py +4 -4
modeling_internimage.py
CHANGED
@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
|
|
853 |
remove_center=config.remove_center, # for InternImage-H/G
|
854 |
)
|
855 |
|
856 |
-
def forward(self,
|
857 |
-
return self.model.forward_features(
|
858 |
|
859 |
|
860 |
class InternImageModelForImageClassification(PreTrainedModel):
|
@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
|
|
888 |
remove_center=config.remove_center, # for InternImage-H/G
|
889 |
)
|
890 |
|
891 |
-
def forward(self,
|
892 |
-
outputs = self.model.forward(
|
893 |
|
894 |
if labels is not None:
|
895 |
logits = outputs['logits']
|
|
|
853 |
remove_center=config.remove_center, # for InternImage-H/G
|
854 |
)
|
855 |
|
856 |
+
def forward(self, pixel_values):
|
857 |
+
return self.model.forward_features(pixel_values)
|
858 |
|
859 |
|
860 |
class InternImageModelForImageClassification(PreTrainedModel):
|
|
|
888 |
remove_center=config.remove_center, # for InternImage-H/G
|
889 |
)
|
890 |
|
891 |
+
def forward(self, pixel_values, labels=None):
|
892 |
+
outputs = self.model.forward(pixel_values)
|
893 |
|
894 |
if labels is not None:
|
895 |
logits = outputs['logits']
|