Astaxanthin committed on
Commit 06471d0 · verified · 1 Parent(s): 9ab05d7

Upload 6 files

Files changed (4):
  1. config.json +35 -26
  2. model.safetensors +2 -2
  3. modeling_keep.py +29 -17
  4. pytorch_model.bin +1 -1
config.json CHANGED
@@ -1,34 +1,43 @@
 {
-  "model_type": "keep",
-  "vision_config": {
-    "model_type": "vit",
-    "img_size": 224,
-    "patch_size": 16,
-    "hidden_size": 1024,
-    "num_heads": 16,
-    "num_layers": 24,
-    "mlp_ratio": 4.0,
-    "qkv_bias": true,
-    "drop_rate": 0.0,
-    "attn_drop_rate": 0.0,
-    "init_values": 1e-5,
-    "num_classes": 0,
-    "dynamic_img_size": true
+  "architectures": [
+    "KEEPModel"
+  ],
+  "auto_map": {
+    "AutoConfig": "modeling_keep.KEEPConfig",
+    "AutoModel": "modeling_keep.KEEPModel"
   },
+  "model_type": "keep",
+  "projection_dim": 768,
   "text_config": {
-    "model_type": "bert",
-    "vocab_size": 30522,
-    "hidden_size": 768,
-    "num_hidden_layers": 12,
-    "num_attention_heads": 12,
-    "intermediate_size": 3072,
+    "attention_probs_dropout_prob": 0.1,
     "hidden_act": "gelu",
     "hidden_dropout_prob": 0.1,
-    "attention_probs_dropout_prob": 0.1,
+    "hidden_size": 768,
+    "initializer_range": 0.02,
+    "intermediate_size": 3072,
+    "layer_norm_eps": 1e-12,
     "max_position_embeddings": 512,
+    "model_type": "bert",
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12,
     "type_vocab_size": 2,
-    "initializer_range": 0.02,
-    "layer_norm_eps": 1e-12
+    "vocab_size": 30522
   },
-  "projection_dim": 768
-}
+  "torch_dtype": "float32",
+  "transformers_version": "4.46.3",
+  "vision_config": {
+    "attn_drop_rate": 0.0,
+    "drop_rate": 0.0,
+    "dynamic_img_size": true,
+    "hidden_size": 1024,
+    "img_size": 224,
+    "init_values": 1e-05,
+    "mlp_ratio": 4.0,
+    "model_type": "vit",
+    "num_classes": 0,
+    "num_heads": 16,
+    "num_layers": 24,
+    "patch_size": 16,
+    "qkv_bias": true
+  }
+}
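The added "architectures" and "auto_map" keys are what let the transformers Auto classes resolve the custom code in modeling_keep.py at load time. A minimal loading sketch; the repo id "Astaxanthin/KEEP" is an assumption based on the committer name, and trust_remote_code=True is required because the classes live in the repo rather than in transformers:

from transformers import AutoConfig, AutoModel

# AutoConfig/AutoModel follow the "auto_map" entries above to
# KEEPConfig/KEEPModel defined in modeling_keep.py.
config = AutoConfig.from_pretrained("Astaxanthin/KEEP", trust_remote_code=True)
model = AutoModel.from_pretrained("Astaxanthin/KEEP", trust_remote_code=True)
print(type(model).__name__)  # -> KEEPModel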
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f8ed817b279417d3c67a842477fcae056212eda4e24350c860fe2ee70d9623fc
-size 1656902036
+oid sha256:82f610d5359aca67b5fd5d841009f26db430ae78d0693743589c0a727b0a146d
+size 1656902084
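The weight files are stored as Git LFS pointers, so the change above is only to the pointer's oid (the blob's SHA-256) and byte size. A sketch for checking a downloaded blob against the pointer; the local filename is an assumption:

import hashlib

h = hashlib.sha256()
with open("model.safetensors", "rb") as f:            # assumed local path
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)

# Should match the oid recorded in the new pointer above.
print(h.hexdigest() == "82f610d5359aca67b5fd5d841009f26db430ae78d0693743589c0a727b0a146d")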
modeling_keep.py CHANGED
@@ -6,48 +6,61 @@ import numpy
 from torchvision import transforms
 from PIL import Image
 
+class RenameLayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: float = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.weight = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.weight) if self.inplace else x * self.weight
+
+timm.models.vision_transformer.LayerScale = RenameLayerScale
+
 class KEEPConfig(PretrainedConfig):
-    model_type = "keep"  # mark the model type
+    model_type = "keep"  #
 
     def __init__(
         self,
-        vision_config=None,  # Vision Encoder config
-        text_config=None,  # Text Encoder config
-        projection_dim=768,  # projection dimension, defaults to 768
+        vision_config=None,  # Vision Encoder
+        text_config=None,  # Text Encoder
+        projection_dim=768,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.vision_config = vision_config
         self.text_config = text_config
         self.projection_dim = projection_dim
-
 
 class KEEPModel(PreTrainedModel):
-    config_class = KEEPConfig  # bind to the custom config class
+    config_class = KEEPConfig  #
 
     def __init__(self, config):
         super().__init__(config)
 
-        # Vision Encoder (a timm ViT)
+        # Vision Encoder
         vision_config = config.vision_config
         self.visual = timm.create_model(
             "vit_large_patch16_224",
             pretrained=False,
-            img_size=vision_config.get("img_size", 224),
-            patch_size=vision_config.get("patch_size", 16),
-            init_values=vision_config.get("init_values", 1e-5),
-            num_classes=vision_config.get("num_classes", 0),
-            dynamic_img_size=vision_config.get("dynamic_img_size", True),
+            img_size=vision_config["img_size"],
+            patch_size=vision_config["patch_size"],
+            init_values=vision_config["init_values"],
+            num_classes=vision_config["num_classes"],
         )
 
-        # linear projection head: project the Vision Encoder output to 768 dims
         self.visual_head = nn.Sequential(
             nn.Linear(self.visual.num_features, config.projection_dim),
             nn.GELU(),
             nn.Linear(config.projection_dim, config.projection_dim)
         )
 
-        # Text Encoder (based on PubMedBERT)
+        # Text Encoder
         text_config = BertConfig(**config.text_config)
         self.text = BertModel(text_config)
 
@@ -69,8 +82,7 @@ class KEEPModel(PreTrainedModel):
 
         text_features = self.encode_text(text_inputs)
 
-        # return the two features separately
         return {
-            "vision_features": vision_features,  # vision features
-            "text_features": text_features  # text features
+            "vision_features": vision_features,
+            "text_features": text_features
         }
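The most consequential change here is RenameLayerScale, which shadows timm's LayerScale before timm.create_model runs, so the scale parameter is registered as "weight" instead of timm's default "gamma". The likely motivation, an inference not stated in the commit, is that transformers rewrites state-dict keys containing "gamma"/"beta" to "weight"/"bias" when loading a PreTrainedModel, which would otherwise leave timm's LayerScale parameters unmatched by the checkpoint. A small sketch of the effect, assuming modeling_keep.py from this repo is importable:

import timm
import modeling_keep  # importing it applies the LayerScale monkey-patch

vit = timm.create_model(
    "vit_large_patch16_224",
    pretrained=False,
    init_values=1e-5,  # a non-None init_values enables LayerScale blocks
    num_classes=0,
)
# With the patch the keys end in "ls1.weight"/"ls2.weight"; stock timm
# would register "ls1.gamma"/"ls2.gamma" instead.
print([k for k in vit.state_dict() if k.endswith(("ls1.weight", "ls1.gamma"))][:2])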
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:19a9ef805fcde4f1a255892ed755f960214fb19da59e87d2fc0de49d4683946b
+oid sha256:526a677bf714388d2485a45f5c372505a9874d56a86645b154e2d46ab60d87ca
 size 1657016149