Rjavenger commited on
Commit
24009ab
·
verified ·
1 Parent(s): 7a59dab

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +37 -0
utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ from torchvision.models import resnet18
5
+
6
+ # Same as your original custom classifier
7
+ class ResNet18Classifier(torch.nn.Module):
8
+ def __init__(self, num_classes=3, pretrained=False):
9
+ super(ResNet18Classifier, self).__init__()
10
+ self.model = resnet18(pretrained=pretrained)
11
+ self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
12
+
13
+ def forward(self, x):
14
+ return self.model(x)
15
+
16
+ def load_model(model_path="model/best_classification_model.pth", num_classes=3):
17
+ model = ResNet18Classifier(num_classes=num_classes, pretrained=False)
18
+ state_dict = torch.load(model_path, map_location='cpu')
19
+ model.load_state_dict(state_dict)
20
+ model.eval()
21
+ return model
22
+
23
+ def predict_image(image_path, model, class_names):
24
+ transform = transforms.Compose([
25
+ transforms.Resize((224, 224)),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
28
+ std=[0.229, 0.224, 0.225])
29
+ ])
30
+ image = Image.open(image_path).convert('RGB')
31
+ image_tensor = transform(image).unsqueeze(0)
32
+
33
+ with torch.no_grad():
34
+ outputs = model(image_tensor)
35
+ _, predicted = torch.max(outputs, 1)
36
+
37
+ return class_names[predicted.item()]