sachin committed on
Commit
c7a14ad
·
1 Parent(s): a8c8fe0

Modified config to transformers standard

Browse files
Files changed (1) hide show
  1. src/config.py +69 -0
src/config.py CHANGED
@@ -1,6 +1,7 @@
1
  import pathlib
2
 
3
  import pydantic
 
4
 
5
  MAX_DOWNLOAD_TIME = 0.2
6
 
@@ -16,6 +17,74 @@ class DataConfig(pydantic.BaseModel):
16
  dataset: str = small_dataset
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class ModelConfig(pydantic.BaseModel):
20
  text_model: str = "microsoft/xtremedistil-l6-h256-uncased" # 51 mb
21
  vision_model: str = "edgenext_small" # 20 mb
 
1
  import pathlib
2
 
3
  import pydantic
4
+ from transformers import PretrainedConfig
5
 
6
  MAX_DOWNLOAD_TIME = 0.2
7
 
 
17
  dataset: str = small_dataset
18
 
19
 
20
class TinyCLIPTextConfig(PretrainedConfig):
    """Configuration for the text side of TinyCLIP.

    Every named constructor argument is stored on the instance under the
    same name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        text_model: Identifier of the text backbone.
        projection_layers: Stored as-is (presumably the depth of the
            projection head -- confirm against the model code).
        embed_dims: Stored as-is (presumably the width of the output
            embedding -- confirm against the model code).
        max_len: Stored as-is (presumably the maximum token length --
            confirm against the tokenizer call site).
        cls_type: Stored as-is (presumably selects CLS-style pooling --
            confirm against the model code).
        **kwargs: Passed through unchanged to the base class.
    """

    model_type = "text"

    def __init__(
        self,
        text_model: str = "microsoft/xtremedistil-l6-h256-uncased",
        projection_layers: int = 3,
        embed_dims: int = 512,
        max_len: int = 128,
        cls_type: bool = True,
        **kwargs,
    ):
        # Attributes are assigned in the same order as the signature so the
        # instance ``__dict__`` keeps a stable layout.
        self.text_model, self.projection_layers = text_model, projection_layers
        self.embed_dims, self.max_len, self.cls_type = embed_dims, max_len, cls_type
        # The base class is initialised last, after the fields above exist.
        super().__init__(**kwargs)
38
+
39
+
40
class TinyCLIPVisionConfig(PretrainedConfig):
    """Configuration for the vision side of TinyCLIP.

    Every named constructor argument is stored on the instance under the
    same name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        vision_model: Identifier of the vision backbone.
        projection_layers: Stored as-is (presumably the depth of the
            projection head -- confirm against the model code).
        embed_dims: Stored as-is (presumably the width of the output
            embedding -- confirm against the model code).
        **kwargs: Passed through unchanged to the base class.
    """

    model_type = "vision"

    def __init__(
        self,
        vision_model: str = "edgenext_small",
        projection_layers: int = 3,
        embed_dims: int = 512,
        **kwargs,
    ):
        # Attributes are assigned in signature order, then the base class
        # is initialised with whatever keyword arguments remain.
        self.vision_model, self.projection_layers, self.embed_dims = (
            vision_model,
            projection_layers,
            embed_dims,
        )
        super().__init__(**kwargs)
54
+
55
+
56
class TinyCLIPConfig(PretrainedConfig):
    """Top-level TinyCLIP configuration holding a text and a vision config.

    The flat arguments (``text_model``, ``vision_model``, ...) are used to
    build a ``TinyCLIPTextConfig`` and a ``TinyCLIPVisionConfig``, stored as
    ``text_config`` and ``vision_config``. ``projection_layers`` and
    ``embed_dim`` are shared by both sub-configs.

    A serialized config stores the sub-configs as nested dictionaries under
    ``text_config`` / ``vision_config``. When those keys are supplied (as a
    dict or as an already-built config object) they take precedence over the
    flat arguments, so a saved config round-trips without losing its values.

    Args:
        text_model: Identifier of the text backbone.
        vision_model: Identifier of the vision backbone.
        projection_layers: Forwarded to both sub-configs.
        embed_dim: Forwarded to both sub-configs as ``embed_dims``.
        max_len: Forwarded to the text sub-config.
        cls_type: Forwarded to the text sub-config.
        freeze_vision_base: Stored as-is (presumably disables training of
            the vision backbone -- confirm against the training code).
        freeze_text_base: Stored as-is (presumably disables training of the
            text backbone -- confirm against the training code).
        loss_type: Stored as-is; name of the loss to use.
        **kwargs: May contain ``text_config`` / ``vision_config`` (dict or
            config instance); everything else is passed to the base class.
    """

    # NOTE(review): "clip" is also the model_type used by the CLIP config that
    # ships with transformers -- confirm this does not clash if this class is
    # ever registered with AutoConfig.
    model_type = "clip"

    def __init__(
        self,
        text_model: str = "microsoft/xtremedistil-l6-h256-uncased",
        vision_model: str = "edgenext_small",
        projection_layers: int = 3,
        embed_dim: int = 512,
        max_len: int = 128,
        cls_type: bool = True,
        freeze_vision_base: bool = False,
        freeze_text_base: bool = False,
        loss_type: str = "cyclip",
        **kwargs,
    ):
        # Take the nested configs out of kwargs. If they were left in, the
        # base-class constructor would set them as attributes and overwrite
        # the sub-config objects built below with plain dicts.
        text_config = kwargs.pop("text_config", None)
        vision_config = kwargs.pop("vision_config", None)

        if text_config is None:
            text_config = TinyCLIPTextConfig(
                text_model=text_model,
                projection_layers=projection_layers,
                embed_dims=embed_dim,
                max_len=max_len,
                cls_type=cls_type,
            )
        elif isinstance(text_config, dict):
            # Serialized form: rebuild the sub-config from its own fields.
            text_config = TinyCLIPTextConfig(**text_config)

        if vision_config is None:
            vision_config = TinyCLIPVisionConfig(
                vision_model=vision_model,
                projection_layers=projection_layers,
                embed_dims=embed_dim,
            )
        elif isinstance(vision_config, dict):
            vision_config = TinyCLIPVisionConfig(**vision_config)

        self.text_config = text_config
        self.vision_config = vision_config
        self.freeze_vision_base = freeze_vision_base
        self.freeze_text_base = freeze_text_base
        self.loss_type = loss_type
        super().__init__(**kwargs)
86
+
87
+
88
  class ModelConfig(pydantic.BaseModel):
89
  text_model: str = "microsoft/xtremedistil-l6-h256-uncased" # 51 mb
90
  vision_model: str = "edgenext_small" # 20 mb