Upload modeling_keep.py
Browse files
- modeling_keep.py +6 -5
modeling_keep.py
CHANGED
@@ -29,14 +29,15 @@ class KEEPModel(PreTrainedModel):
|
|
29 |
super().__init__(config)
|
30 |
|
31 |
# Vision Encoder (基于 timm 的 ViT)
|
|
|
32 |
self.visual = timm.create_model(
|
33 |
"vit_large_patch16_224",
|
34 |
pretrained=False,
|
35 |
-
img_size=224,
|
36 |
-
patch_size=16,
|
37 |
-
init_values=1e-5,
|
38 |
-
num_classes=0,
|
39 |
-
dynamic_img_size=True,
|
40 |
)
|
41 |
|
42 |
# 线性投影层,将 Vision Encoder 的输出投影到 768 维
|
|
|
29 |
super().__init__(config)
|
30 |
|
31 |
# Vision Encoder (基于 timm 的 ViT)
|
32 |
+
vision_config = config.vision_config
|
33 |
self.visual = timm.create_model(
|
34 |
"vit_large_patch16_224",
|
35 |
pretrained=False,
|
36 |
+
img_size=vision_config.get("img_size", 224),
|
37 |
+
patch_size=vision_config.get("patch_size", 16),
|
38 |
+
init_values=vision_config.get("init_values", 1e-5),
|
39 |
+
num_classes=vision_config.get("num_classes", 0),
|
40 |
+
dynamic_img_size=vision_config.get("dynamic_img_size", True),
|
41 |
)
|
42 |
|
43 |
# 线性投影层,将 Vision Encoder 的输出投影到 768 维
|