Spaces:
Runtime error
Runtime error
import timm | |
import torch.nn as nn | |
import torch | |
def get_efficientnet(model_name: str, pretrained: bool = True, **kwargs) -> nn.Module:
    """
    Build a timm model by name.

    Parameters
    ----------
    model_name : str
        Architecture name passed to ``timm.create_model``
        (e.g. ``'efficientnet_b0'``).
    pretrained : bool, default True
        Whether timm should load pretrained weights. Defaults to ``True``,
        which matches the previously hard-coded behaviour.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``timm.create_model``.

    Returns
    -------
    nn.Module
        The model returned by ``timm.create_model``.
    """
    return timm.create_model(model_name, pretrained=pretrained, **kwargs)
class CustomEfficientNet(nn.Module):
    """
    timm backbone whose ``classifier`` is replaced by a two-layer MLP head.

    The new head is
    ``Linear(in_features, hidden_size) -> ReLU -> Linear(hidden_size, target_size)``.
    A ``Dropout`` can optionally be placed before each ``Linear``.

    Parameters
    ----------
    model_name : str, default 'efficientnet_b0'
        Name of the timm architecture to build. The built model must expose a
        ``classifier`` attribute that has an ``in_features`` field.
    target_size : int, default 4
        Number of units for the output layer.
    pretrained : bool, default True
        Determine if pretrained weights are used.
    hidden_size : int, default 256
        Width of the hidden layer in the replacement head (keyword-only).
    dropout : float, default 0.0
        Dropout probability applied before each linear layer of the head
        (keyword-only). Must be in ``[0, 1)``. With the default of ``0.0``, no
        dropout modules are inserted. The head then has the same three-module
        layout, and the same ``state_dict`` keys, as before this option
        existed. A positive value shifts the module indices inside the head.

    Attributes
    ----------
    model : nn.Module
        The timm model with its ``classifier`` replaced by the new head.

    Raises
    ------
    ValueError
        If ``dropout`` is outside ``[0, 1)``.
    """

    def __init__(
        self,
        model_name: str = 'efficientnet_b0',
        target_size: int = 4,
        pretrained: bool = True,
        *,
        hidden_size: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {dropout!r}")

        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Input width of the original classifier; the new head consumes the
        # same feature vector.
        in_features = self.model.classifier.in_features

        layers = []
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        layers.extend([nn.Linear(in_features, hidden_size), nn.ReLU()])
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_size, target_size))
        self.model.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and replacement head and return the output."""
        return self.model(x)