Spaces:
Runtime error
Runtime error
File size: 1,187 Bytes
9fbf078 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import timm
import torch.nn as nn
import torch
def get_efficientnet(model_name: str, pretrained: bool = True, **kwargs) -> nn.Module:
    """
    Create a model by name with ``timm.create_model``.

    Parameters
    ----------
    model_name : str
        Architecture name understood by ``timm.create_model``
        (e.g. ``'efficientnet_b0'``).
    pretrained : bool
        Whether to load pretrained weights. Defaults to ``True``, which was
        the previously hard-coded behavior.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``timm.create_model``.

    Returns
    -------
    nn.Module
        The model returned by ``timm.create_model``.
    """
    return timm.create_model(model_name, pretrained=pretrained, **kwargs)
class CustomEfficientNet(nn.Module):
    """
    This class defines a custom EfficientNet network.

    A timm backbone is created by name, and its ``classifier`` attribute is
    replaced with a small MLP head:
    ``Linear(in_features, hidden_size) -> ReLU -> Linear(hidden_size, target_size)``.
    Optionally, dropout is inserted before each linear layer.

    Parameters
    ----------
    model_name : str
        Name of the timm architecture to instantiate. The created model must
        expose a ``classifier`` attribute with an ``in_features`` field,
        because the new head reads its input width from there.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Determine if pretrained weights are used.
    hidden_size : int
        Number of units in the hidden layer of the new head. Defaults to 256,
        the previously hard-coded value.
    dropout : float
        Dropout probability applied before each linear layer of the head.
        Defaults to 0.0, in which case no ``nn.Dropout`` modules are added.
        The head is then exactly ``[Linear, ReLU, Linear]``, so existing
        ``state_dict`` keys (``classifier.0.*``, ``classifier.2.*``) still
        match.

    Attributes
    ----------
    model : nn.Module
        The timm backbone with its ``classifier`` replaced by the MLP head.
    """

    def __init__(self, model_name: str = 'efficientnet_b0',
                 target_size: int = 4, pretrained: bool = True,
                 hidden_size: int = 256, dropout: float = 0.0):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Modify the classifier layer: keep the backbone's feature width and
        # swap in a new MLP head sized for `target_size` outputs.
        in_features = self.model.classifier.in_features
        self.model.classifier = self._build_head(
            in_features, hidden_size, target_size, dropout
        )

    @staticmethod
    def _build_head(in_features: int, hidden_size: int, target_size: int,
                    dropout: float = 0.0) -> nn.Sequential:
        """Build the MLP head, adding ``nn.Dropout`` layers only when ``dropout`` is non-zero."""
        # With dropout == 0 the layout is [Linear, ReLU, Linear], matching the
        # original module indices. A non-zero value is range-checked by
        # nn.Dropout itself.
        use_dropout = dropout != 0
        layers = []
        if use_dropout:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(in_features, hidden_size))
        layers.append(nn.ReLU())
        if use_dropout:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_size, target_size))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and the replaced head and return the result.

        NOTE(review): the expected input layout depends on the timm backbone
        (presumably an image batch); it is not checked here.
        """
        return self.model(x)
|