young-man commited on
Commit
51dddfe
·
1 Parent(s): 0a1f279

Delete modeling_resnet.py

Browse files
Files changed (1) hide show
  1. modeling_resnet.py +0 -54
modeling_resnet.py DELETED
@@ -1,54 +0,0 @@
1
- from transformers import PreTrainedModel
2
- from timm.models.resnet import BasicBlock, Bottleneck, ResNet
3
- from .configuration_resnet import ResnetConfig
4
- import torch
5
-
6
- BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
7
-
8
-
9
- class ResnetModel(PreTrainedModel):
10
- config_class = ResnetConfig
11
-
12
- def __init__(self, config):
13
- super().__init__(config)
14
- block_layer = BLOCK_MAPPING[config.block_type]
15
- self.model = ResNet(
16
- block_layer,
17
- config.layers,
18
- num_classes=config.num_classes,
19
- in_chans=config.input_channels,
20
- cardinality=config.cardinality,
21
- base_width=config.base_width,
22
- stem_width=config.stem_width,
23
- stem_type=config.stem_type,
24
- avg_down=config.avg_down,
25
- )
26
-
27
- def forward(self, tensor):
28
- return self.model.forward_features(tensor)
29
-
30
-
31
- class ResnetModelForImageClassification(PreTrainedModel):
32
- config_class = ResnetConfig
33
-
34
- def __init__(self, config):
35
- super().__init__(config)
36
- block_layer = BLOCK_MAPPING[config.block_type]
37
- self.model = ResNet(
38
- block_layer,
39
- config.layers,
40
- num_classes=config.num_classes,
41
- in_chans=config.input_channels,
42
- cardinality=config.cardinality,
43
- base_width=config.base_width,
44
- stem_width=config.stem_width,
45
- stem_type=config.stem_type,
46
- avg_down=config.avg_down,
47
- )
48
-
49
- def forward(self, tensor, labels=None):
50
- logits = self.model(tensor)
51
- if labels is not None:
52
- loss = torch.nn.cross_entropy(logits, labels)
53
- return {"loss": loss, "logits": logits}
54
- return {"logits": logits}