Spaces:
Runtime error
Runtime error
import torch | |
import torchvision | |
from torchvision import transforms | |
from torch import nn | |
def create_ResNetb34_model(num_classes: int = 3, seed: int = 42, *, freeze_base: bool = False):
    """Create a pretrained ResNet-34 with a fresh classifier head and its transforms.

    The ImageNet head (``model.fc``) is replaced by ``Dropout -> Linear`` so the
    model outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output units in the new classifier head.
            Defaults to 3.
        seed: Value passed to ``torch.manual_seed`` right before the new head
            is created, so its initial weights are reproducible. Defaults to 42.
        freeze_base: If True, set ``requires_grad = False`` on every pretrained
            parameter so only the new head is trainable. Defaults to False,
            which leaves the whole network trainable.

    Returns:
        A ``(model, transform)`` tuple: the ResNet-34 ``nn.Module`` and a
        ``transforms.Compose`` wrapping the preprocessing bundled with the
        pretrained weights.
    """
    # 1. Resolve the pretrained ResNet-34 weights once and reuse them below.
    weights = torchvision.models.ResNet34_Weights.DEFAULT

    # 2. Preprocessing pipeline that matches the pretrained weights.
    transform = transforms.Compose([weights.transforms()])

    # 3. Build the pretrained model from the same weights object.
    model = torchvision.models.resnet34(weights=weights)

    # 4. Optionally freeze the pretrained layers. This runs before the head is
    #    replaced, so the new head's parameters stay trainable.
    if freeze_base:
        for param in model.parameters():
            param.requires_grad = False

    # 5. Replace the classifier head, seeding first for reproducible init.
    #    torchvision's ResNet applies `fc` in forward(); assigning to any other
    #    attribute name would leave the 1000-class ImageNet head in place.
    #    NOTE: state dicts saved when the head lived under `classifier` will
    #    not load strictly into this model.
    head_in_features = model.fc.in_features  # width of the pooled features
    torch.manual_seed(seed)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=head_in_features, out_features=num_classes),
    )
    return model, transform