sehyun
First Commit
afb90db
raw
history blame contribute delete
709 Bytes
import torch, torchvision
from torch import nn
def build_effnetb1(
    num_classes: int = 4,
    dropout: float = 0.0,
    device: "str | torch.device" = "cpu",
):
    """Build a pretrained EfficientNet-B1 with a frozen backbone and a new head.

    The model is created with torchvision's ``EfficientNet_B1_Weights.DEFAULT``
    weights. Every pretrained parameter is frozen. The classifier is then
    replaced by a freshly initialised ``Dropout -> Linear`` head, which is the
    only trainable part.

    Args:
        num_classes: Number of output units of the new linear head
            (default 4, matching the original hard-coded value).
        dropout: Dropout probability applied before the linear layer. With the
            default of 0 the dropout layer is a pass-through.
        device: Device the whole model is moved to (default ``"cpu"``).

    Returns:
        A ``(model, transforms)`` tuple:
        - ``model`` is the modified EfficientNet-B1, tagged with a ``name``
          attribute set to ``"effnetb1"``.
        - ``transforms`` is the preprocessing pipeline returned by the
          pretrained weights' ``transforms()``.
    """
    weights = torchvision.models.EfficientNet_B1_Weights.DEFAULT
    model = torchvision.models.efficientnet_b1(weights=weights)
    transforms = weights.transforms()
    model.name = "effnetb1"

    # Freeze the pretrained parameters. The head built below is created
    # afterwards, so its parameters keep requires_grad=True and stay trainable.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the stock classifier's final Linear
    # instead of hard-coding 1280, so it always matches the backbone output.
    in_features = model.classifier[-1].in_features

    # NOTE(review): torchvision's stock EfficientNet-B1 head uses a non-zero
    # dropout; the default p=0 here disables it. Confirm that is intended.
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )

    # Move the whole model, not only the head, so the backbone and the
    # classifier can never end up on different devices.
    return model.to(device), transforms