zhangfaen commited on
Commit
e7251dc
·
verified ·
1 Parent(s): 50ca18c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_resnet.py +55 -46
modeling_resnet.py CHANGED
@@ -1,46 +1,55 @@
1
- from transformers import PretrainedConfig
2
- from typing import List
3
- from pprint import pprint
4
-
5
-
6
- class ResnetConfig(PretrainedConfig):
7
- model_type = "faen_resnet"
8
-
9
- def __init__(
10
- self,
11
- block_type="bottleneck",
12
- layers: List[int] = [3, 4, 6, 3],
13
- num_classes: int = 1000,
14
- input_channels: int = 3,
15
- cardinality: int = 1,
16
- base_width: int = 64,
17
- stem_width: int = 64,
18
- stem_type: str = "",
19
- avg_down: bool = False,
20
- **kwargs,
21
- ):
22
- if block_type not in ["basic", "bottleneck"]:
23
- raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.")
24
- if stem_type not in ["", "deep", "deep-tiered"]:
25
- raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")
26
-
27
- self.block_type = block_type
28
- self.layers = layers
29
- self.num_classes = num_classes
30
- self.input_channels = input_channels
31
- self.cardinality = cardinality
32
- self.base_width = base_width
33
- self.stem_width = stem_width
34
- self.stem_type = stem_type
35
- self.avg_down = avg_down
36
- super().__init__(**kwargs)
37
-
38
- if __name__ == "__main__":
39
- resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
40
- print("init a ResnetConfig, it is:\n")
41
- pprint(resnet50d_config)
42
- resnet50d_config.save_pretrained("./")
43
- resnet50d_config = ResnetConfig.from_pretrained("./")
44
- print("\n")
45
- print("saved to file config.json and reload it from config.json and it is:\n")
46
- pprint(resnet50d_config)
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from timm.models.resnet import BasicBlock, Bottleneck, ResNet
3
+ from configuration_resnet import ResnetConfig
4
+ import torch
5
+
6
+ BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
7
+
8
+
9
+ class ResnetModel(PreTrainedModel):
10
+ config_class = ResnetConfig
11
+
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+ block_layer = BLOCK_MAPPING[config.block_type]
15
+ self.model = ResNet(
16
+ block_layer,
17
+ config.layers,
18
+ num_classes=config.num_classes,
19
+ in_chans=config.input_channels,
20
+ cardinality=config.cardinality,
21
+ base_width=config.base_width,
22
+ stem_width=config.stem_width,
23
+ stem_type=config.stem_type,
24
+ avg_down=config.avg_down,
25
+ )
26
+
27
+ def forward(self, tensor):
28
+ return self.model.forward_features(tensor)
29
+
30
+
31
+ class ResnetModelForImageClassification(PreTrainedModel):
32
+ config_class = ResnetConfig
33
+
34
+ def __init__(self, config):
35
+ super().__init__(config)
36
+ block_layer = BLOCK_MAPPING[config.block_type]
37
+ self.model = ResNet(
38
+ block_layer,
39
+ config.layers,
40
+ num_classes=config.num_classes,
41
+ in_chans=config.input_channels,
42
+ cardinality=config.cardinality,
43
+ base_width=config.base_width,
44
+ stem_width=config.stem_width,
45
+ stem_type=config.stem_type,
46
+ avg_down=config.avg_down,
47
+ )
48
+
49
+ def forward(self, tensor, labels=None):
50
+ logits = self.model(tensor)
51
+ if labels is not None:
52
+ loss = torch.nn.cross_entropy(logits, labels)
53
+ return {"loss": loss, "logits": logits}
54
+ return {"logits": logits}
55
+