Astaxanthin commited on
Commit
9ab05d7
·
verified ·
1 Parent(s): c88c3a0

Upload modeling_keep.py

Browse files
Files changed (1) hide show
  1. modeling_keep.py +6 -5
modeling_keep.py CHANGED
@@ -29,14 +29,15 @@ class KEEPModel(PreTrainedModel):
29
  super().__init__(config)
30
 
31
  # Vision Encoder (基于 timm 的 ViT)
 
32
  self.visual = timm.create_model(
33
  "vit_large_patch16_224",
34
  pretrained=False,
35
- img_size=224,
36
- patch_size=16,
37
- init_values=1e-5,
38
- num_classes=0,
39
- dynamic_img_size=True,
40
  )
41
 
42
  # 线性投影层,将 Vision Encoder 的输出投影到 768 维
 
29
  super().__init__(config)
30
 
31
  # Vision Encoder (基于 timm 的 ViT)
32
+ vision_config = config.vision_config
33
  self.visual = timm.create_model(
34
  "vit_large_patch16_224",
35
  pretrained=False,
36
+ img_size=vision_config.get("img_size", 224),
37
+ patch_size=vision_config.get("patch_size", 16),
38
+ init_values=vision_config.get("init_values", 1e-5),
39
+ num_classes=vision_config.get("num_classes", 0),
40
+ dynamic_img_size=vision_config.get("dynamic_img_size", True),
41
  )
42
 
43
  # 线性投影层,将 Vision Encoder 的输出投影到 768 维