jackieeepeng commited on
Commit
e293f23
·
verified ·
1 Parent(s): a1ca341

Upload model

Browse files
Files changed (2) hide show
  1. config.json +2 -1
  2. diy_model.py +64 -0
config.json CHANGED
@@ -3,7 +3,8 @@
3
  "ResnetModelForImageClassification"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "diy_config.ResnetConfig"
 
7
  },
8
  "avg_down": true,
9
  "base_width": 64,
 
3
  "ResnetModelForImageClassification"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "diy_config.ResnetConfig",
7
+ "AutoModelForImageClassification": "diy_model.ResnetModelForImageClassification"
8
  },
9
  "avg_down": true,
10
  "base_width": 64,
diy_model.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from timm.models.resnet import BasicBlock, Bottleneck, ResNet
3
+ from diy_config import ResnetConfig
4
+
5
+
6
+ BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
7
+
8
+
9
+ class ResnetModel(PreTrainedModel):
10
+ config_class = ResnetConfig
11
+
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+ block_layer = BLOCK_MAPPING[config.block_type]
15
+ self.model = ResNet(
16
+ block_layer,
17
+ config.layers,
18
+ num_classes=config.num_classes,
19
+ in_chans=config.input_channels,
20
+ cardinality=config.cardinality,
21
+ base_width=config.base_width,
22
+ stem_width=config.stem_width,
23
+ stem_type=config.stem_type,
24
+ avg_down=config.avg_down,
25
+ )
26
+
27
+ def forward(self, tensor):
28
+ return self.model.forward_features(tensor)
29
+
30
+ import torch
31
+
32
+
33
+ class ResnetModelForImageClassification(PreTrainedModel):
34
+ config_class = ResnetConfig
35
+
36
+ def __init__(self, config):
37
+ super().__init__(config)
38
+ block_layer = BLOCK_MAPPING[config.block_type]
39
+ self.model = ResNet(
40
+ block_layer,
41
+ config.layers,
42
+ num_classes=config.num_classes,
43
+ in_chans=config.input_channels,
44
+ cardinality=config.cardinality,
45
+ base_width=config.base_width,
46
+ stem_width=config.stem_width,
47
+ stem_type=config.stem_type,
48
+ avg_down=config.avg_down,
49
+ )
50
+
51
+ def forward(self, tensor, labels=None):
52
+ logits = self.model(tensor)
53
+ if labels is not None:
54
+ loss = torch.nn.functional.cross_entropy(logits, labels)
55
+ return {"loss": loss, "logits": logits}
56
+ return {"logits": logits}
57
+
58
+ # resnet50d_config = ResnetConfig.from_pretrained("custom-resnet")
59
+ # resnet50d = ResnetModelForImageClassification(resnet50d_config)
60
+ # print(resnet50d)
61
+ # import timm
62
+ #
63
+ # pretrained_model = timm.create_model("resnet50d", pretrained=True)
64
+ # resnet50d.model.load_state_dict(pretrained_model.state_dict())