Rjavenger commited on
Commit
b9ee509
·
verified ·
1 Parent(s): bc72f23

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +9 -9
utils.py CHANGED
@@ -1,20 +1,20 @@
1
  import torch
 
2
  from torchvision import transforms
3
  from PIL import Image
4
  from torchvision.models import resnet18
5
 
6
- # Same as your original custom classifier
7
- class ResNet18Classifier(torch.nn.Module):
8
- def __init__(self, num_classes=3, pretrained=False):
9
- super(ResNet18Classifier, self).__init__()
10
- self.model = resnet18(pretrained=pretrained)
11
- self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
12
 
13
  def forward(self, x):
14
- return self.model(x)
15
 
16
  def load_model(model_path="model/best_classification_model.pth", num_classes=3):
17
- model = ResNet18Classifier(num_classes=num_classes, pretrained=False)
18
  state_dict = torch.load(model_path, map_location='cpu')
19
  model.load_state_dict(state_dict)
20
  model.eval()
@@ -34,4 +34,4 @@ def predict_image(image_path, model, class_names):
34
  outputs = model(image_tensor)
35
  _, predicted = torch.max(outputs, 1)
36
 
37
- return class_names[predicted.item()]
 
1
  import torch
2
+ import torch.nn as nn
3
  from torchvision import transforms
4
  from PIL import Image
5
  from torchvision.models import resnet18
6
 
7
+ class ResNet18Classifier(nn.Module):
8
+ def __init__(self, num_classes=3):
9
+ super().__init__()
10
+ self.resnet = resnet18(weights=None) # modern way
11
+ self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
 
12
 
13
  def forward(self, x):
14
+ return self.resnet(x)
15
 
16
  def load_model(model_path="model/best_classification_model.pth", num_classes=3):
17
+ model = ResNet18Classifier(num_classes=num_classes)
18
  state_dict = torch.load(model_path, map_location='cpu')
19
  model.load_state_dict(state_dict)
20
  model.eval()
 
34
  outputs = model(image_tensor)
35
  _, predicted = torch.max(outputs, 1)
36
 
37
+ return class_names[predicted.item()]