Spaces:
Runtime error
Runtime error
File size: 589 Bytes
cb8043e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from torchvision import models as models
import torch.nn as nn
def model(pretrained, requires_grad, num_classes=25):
    """Build a torchvision ResNet-50 with a freshly initialised classifier head.

    Args:
        pretrained: Forwarded unchanged to ``torchvision.models.resnet50`` as
            its ``pretrained=`` argument (whether to load pretrained weights).
        requires_grad: Truthiness decides whether the backbone is trainable.
            If false, every parameter of the loaded network is frozen
            (``requires_grad = False``). If true, every parameter is trainable.
            Either way, the replacement classification head is trainable.
        num_classes: Number of outputs of the new ``fc`` layer. Defaults to 25,
            the class count this factory was originally written for.

    Returns:
        The ResNet-50 module with its ``fc`` layer replaced.
    """
    net = models.resnet50(progress=True, pretrained=pretrained)
    # NOTE(review): ``pretrained=`` is deprecated since torchvision 0.13 in
    # favour of ``weights=``; migrate once the minimum torchvision version
    # used by this project is known to support it.

    # Freeze (False) or unfreeze (True) all backbone parameters in one call.
    # Freshly built parameters already default to requires_grad=True.
    net.requires_grad_(bool(requires_grad))

    # Replace the classifier. Read the input width from the existing layer
    # rather than hard-coding 2048. The new layer is created after the
    # freeze above, so its parameters keep the default requires_grad=True.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net