Luecke commited on
Commit
9d881b0
·
1 Parent(s): c2ac901

separate classification file for prediction only

Browse files
classification/classification_predict.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, models, transforms
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.data import DataLoader, random_split
7
+ import os
8
+ import matplotlib.pyplot as plt
9
+ import random
10
+ from PIL import Image
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ # THIS IS HARDCODED AND DEPENDS ON THE TRAINING. FOR THE MODEL THAT WE TRAINED AND IMPORT HERE, THESE ARE THE CLASSES THAT CAN BE PREDICTED
15
+ class_names = list(['Annonaceae', 'Apocynaceae', 'Astrocarynum aculeatum', 'Attalea', 'Bactris', 'Bellucia', 'Bonomia', 'Byrsonima', 'Caryocar', 'Cecropia palmata', 'Cecropia ulei', 'Cordia', 'Couepia', 'Embaubas', 'Euterpe', 'Fabaceae', 'Genipa', 'Goupia', 'Guatteria', 'Humiria', 'Inga', 'Leopoldinia', 'Licania', 'Malvaceae', 'Meliaceae', 'Miconia', 'Myrcia', 'Ocotea', 'Oenocarpus', 'Pachira', 'Panopsis rubescens', 'Parkia', 'Passovia', 'Phoradendron', 'Pourouma', 'Remijia', 'Ruizterania', 'Salicaceae', 'Sapotaceae', 'Simarouba', 'Tachigali', 'Tapirira', 'Virola', 'Vochysia', 'Xylopia', 'cf. Swartzia'])
16
+ num_classes = 46
17
+
18
+ # Load the pre-trained ResNet50 model
19
+ resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')
20
+
21
+ # Freeze the parameters of the pre-trained model
22
+ for param in resnet50.parameters():
23
+ param.requires_grad = False
24
+
25
+ # Remove the final fully connected layer
26
+ num_ftrs = resnet50.fc.in_features
27
+ resnet50.fc = nn.Identity() # Replace the final layer with an identity function to get the feature vectors
28
+
29
+ # Define a custom neural network with one hidden layer and an output layer
30
+ class CustomNet(nn.Module):
31
+ def __init__(self, num_ftrs, num_classes):
32
+ super(CustomNet, self).__init__()
33
+ self.resnet50 = resnet50
34
+ self.hidden = nn.Linear(num_ftrs, 512)
35
+ self.relu = nn.ReLU()
36
+ self.output = nn.Linear(512, num_classes)
37
+
38
+ def forward(self, x):
39
+ x = self.resnet50(x) # Extract features using the pre-trained model
40
+ x = self.hidden(x) # Pass through the hidden layer
41
+ x = self.relu(x) # Apply ReLU activation
42
+ x = self.output(x) # Output layer
43
+ return x
44
+
45
+
46
+ # Function to load and preprocess the image
47
+ def process_image(image_path):
48
+ data_transform = transforms.Compose([
49
+ transforms.Resize((224, 224)),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
52
+ ])
53
+ image = Image.open(image_path).convert('RGB')
54
+ image = data_transform(image)# data_transforms(image) # <-- data transforms uses all the random cropping as well
55
+ image = image.unsqueeze(0) # Add batch dimension
56
+ return image
57
+
58
+
59
+ # Function to predict the class of a single image
60
+ def predict_single_image(image_path, model):
61
+ # Load the image and preprocess it
62
+ image = process_image(image_path)
63
+
64
+ # Load the model
65
+ model.eval()
66
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
67
+ model = model.to(device)
68
+
69
+ # Pass the image through the model
70
+ with torch.no_grad():
71
+ image = image.to(device)
72
+ outputs = model(image)
73
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
74
+
75
+ # Return the class names and their probabilities as a Pandas Series
76
+ return pd.Series(probabilities.cpu().numpy(), index=class_names).sort_values(ascending=False)
77
+
78
+ def classify(img_path):
79
+ # Path to the single image
80
+ image_path = img_path
81
+
82
+ # Initialize your custom model
83
+ model = CustomNet(num_ftrs, num_classes)
84
+ # Load the trained model weights
85
+ model.load_state_dict(torch.load('./fine_tuned:plant_classifier.pth'))
86
+
87
+ # Predict the class probabilities
88
+ class_probabilities = predict_single_image(image_path, model)
89
+ return class_probabilities