separate classification file for prediction only
Browse files
classification/classification_predict.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.optim as optim
|
4 |
+
from torchvision import datasets, models, transforms
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from torch.utils.data import DataLoader, random_split
|
7 |
+
import os
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import random
|
10 |
+
from PIL import Image
|
11 |
+
import numpy as np
|
12 |
+
import pandas as pd
|
13 |
+
|
14 |
+
# THIS IS HARDCODED AND DEPENDS ON THE TRAINING. FOR THE MODEL THAT WE TRAINED AND IMPORT HERE, THESE ARE THE CLASSES THAT CAN BE PREDICTED
|
15 |
+
class_names = list(['Annonaceae', 'Apocynaceae', 'Astrocarynum aculeatum', 'Attalea', 'Bactris', 'Bellucia', 'Bonomia', 'Byrsonima', 'Caryocar', 'Cecropia palmata', 'Cecropia ulei', 'Cordia', 'Couepia', 'Embaubas', 'Euterpe', 'Fabaceae', 'Genipa', 'Goupia', 'Guatteria', 'Humiria', 'Inga', 'Leopoldinia', 'Licania', 'Malvaceae', 'Meliaceae', 'Miconia', 'Myrcia', 'Ocotea', 'Oenocarpus', 'Pachira', 'Panopsis rubescens', 'Parkia', 'Passovia', 'Phoradendron', 'Pourouma', 'Remijia', 'Ruizterania', 'Salicaceae', 'Sapotaceae', 'Simarouba', 'Tachigali', 'Tapirira', 'Virola', 'Vochysia', 'Xylopia', 'cf. Swartzia'])
|
16 |
+
num_classes = 46
|
17 |
+
|
18 |
+
# Load the pre-trained ResNet50 model
|
19 |
+
resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')
|
20 |
+
|
21 |
+
# Freeze the parameters of the pre-trained model
|
22 |
+
for param in resnet50.parameters():
|
23 |
+
param.requires_grad = False
|
24 |
+
|
25 |
+
# Remove the final fully connected layer
|
26 |
+
num_ftrs = resnet50.fc.in_features
|
27 |
+
resnet50.fc = nn.Identity() # Replace the final layer with an identity function to get the feature vectors
|
28 |
+
|
29 |
+
# Define a custom neural network with one hidden layer and an output layer
|
30 |
+
class CustomNet(nn.Module):
|
31 |
+
def __init__(self, num_ftrs, num_classes):
|
32 |
+
super(CustomNet, self).__init__()
|
33 |
+
self.resnet50 = resnet50
|
34 |
+
self.hidden = nn.Linear(num_ftrs, 512)
|
35 |
+
self.relu = nn.ReLU()
|
36 |
+
self.output = nn.Linear(512, num_classes)
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
x = self.resnet50(x) # Extract features using the pre-trained model
|
40 |
+
x = self.hidden(x) # Pass through the hidden layer
|
41 |
+
x = self.relu(x) # Apply ReLU activation
|
42 |
+
x = self.output(x) # Output layer
|
43 |
+
return x
|
44 |
+
|
45 |
+
|
46 |
+
# Function to load and preprocess the image
|
47 |
+
def process_image(image_path):
|
48 |
+
data_transform = transforms.Compose([
|
49 |
+
transforms.Resize((224, 224)),
|
50 |
+
transforms.ToTensor(),
|
51 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
52 |
+
])
|
53 |
+
image = Image.open(image_path).convert('RGB')
|
54 |
+
image = data_transform(image)# data_transforms(image) # <-- data transforms uses all the random cropping as well
|
55 |
+
image = image.unsqueeze(0) # Add batch dimension
|
56 |
+
return image
|
57 |
+
|
58 |
+
|
59 |
+
# Function to predict the class of a single image
|
60 |
+
def predict_single_image(image_path, model):
|
61 |
+
# Load the image and preprocess it
|
62 |
+
image = process_image(image_path)
|
63 |
+
|
64 |
+
# Load the model
|
65 |
+
model.eval()
|
66 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
67 |
+
model = model.to(device)
|
68 |
+
|
69 |
+
# Pass the image through the model
|
70 |
+
with torch.no_grad():
|
71 |
+
image = image.to(device)
|
72 |
+
outputs = model(image)
|
73 |
+
probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
|
74 |
+
|
75 |
+
# Return the class names and their probabilities as a Pandas Series
|
76 |
+
return pd.Series(probabilities.cpu().numpy(), index=class_names).sort_values(ascending=False)
|
77 |
+
|
78 |
+
def classify(img_path):
|
79 |
+
# Path to the single image
|
80 |
+
image_path = img_path
|
81 |
+
|
82 |
+
# Initialize your custom model
|
83 |
+
model = CustomNet(num_ftrs, num_classes)
|
84 |
+
# Load the trained model weights
|
85 |
+
model.load_state_dict(torch.load('./fine_tuned:plant_classifier.pth'))
|
86 |
+
|
87 |
+
# Predict the class probabilities
|
88 |
+
class_probabilities = predict_single_image(image_path, model)
|
89 |
+
return class_probabilities
|