File size: 926 Bytes
56176e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from transformers import PreTrainedModel
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torch import nn
import torch.nn.functional as F
from .config import MobileNetV3Config
class MobileNetV3Model(PreTrainedModel):
config_class = MobileNetV3Config
def __init__(self, config):
super().__init__(config)
self.model = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
self.model.classifier = nn.Sequential(
nn.Linear(960, 1280),
nn.Hardswish(),
nn.Dropout(p=0.2, inplace=True),
nn.Linear(1280, config.num_classes),
)
def forward(self, tensor, labels=None):
logits = self.model(tensor)
if labels is not None:
loss = F.cross_entropy(logits, labels)
return {"loss": loss, "logits": logits}
return {"logits": logits} |