Spaces:
Runtime error
Runtime error
from torchvision import models as models | |
import torch.nn as nn | |
def model(pretrained: bool, requires_grad: bool, num_classes: int = 25):
    """Build a ResNet-50 whose final layer is replaced by a fresh classifier.

    Args:
        pretrained: Forwarded unchanged to ``models.resnet50`` as its
            ``pretrained`` argument.
        requires_grad: Truthy to leave every backbone parameter trainable,
            falsy to freeze them all.
        num_classes: Output size of the replacement ``fc`` layer. Defaults
            to 25, the class count this file was originally written for.

    Returns:
        The ResNet-50 module with ``fc`` swapped for a new ``nn.Linear``.
    """
    # NOTE(review): recent torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=`` -- confirm against the version this project pins.
    net = models.resnet50(progress=True, pretrained=pretrained)

    # Freeze or unfreeze the whole backbone in a single pass.
    trainable = bool(requires_grad)
    for param in net.parameters():
        param.requires_grad = trainable

    # The head is created after the loop above, so it is not touched by the
    # freeze; its input width is read from the layer it replaces rather than
    # being hard-coded.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net