File size: 926 Bytes
56176e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import PreTrainedModel
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torch import nn
import torch.nn.functional as F
from .config import MobileNetV3Config

class MobileNetV3Model(PreTrainedModel):
    """MobileNetV3-Large image classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The torchvision backbone is built with its default pretrained weights, and
    only its final classification layer is swapped for one producing
    ``config.num_classes`` logits.
    """

    config_class = MobileNetV3Config

    def __init__(self, config):
        """Build the pretrained backbone and resize its output layer.

        Args:
            config: A ``MobileNetV3Config``; only ``config.num_classes`` is read here.
        """
        super().__init__(config)
        # NOTE(review): the DEFAULT weights are fetched (or read from the torch
        # hub cache) on every construction, including inside ``from_pretrained``,
        # where they are then overwritten by the checkpoint -- confirm this is
        # acceptable for offline / deployment use.
        self.model = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
        # torchvision's stock head is
        #   Linear(960, 1280) -> Hardswish -> Dropout(0.2) -> Linear(1280, 1000).
        # Replace only the last Linear so the pretrained hidden layer is kept;
        # rebuilding the whole head would discard those ImageNet weights.
        # Module indices (and therefore state_dict keys such as
        # "model.classifier.3.weight") are unchanged, so saved checkpoints load.
        head = self.model.classifier
        head[-1] = nn.Linear(head[-1].in_features, config.num_classes)

    def forward(self, tensor, labels=None):
        """Compute class logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            tensor: Image batch passed straight to the torchvision model
                (presumably ``(N, 3, H, W)`` floats normalized like the
                pretrained weights expect -- TODO confirm against the data
                pipeline).
            labels: Optional targets for ``F.cross_entropy`` (e.g. ``(N,)``
                integer class indices). When ``None``, no loss is computed.

        Returns:
            dict: ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}``
            when ``labels`` is provided (the ``"loss"`` key is what the HF
            ``Trainer`` looks up).
        """
        logits = self.model(tensor)

        if labels is None:
            return {"logits": logits}

        loss = F.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}