# WeCanopy / classification / classification_predict.py
# Commit 6d70836 by Luecke: "current pipeline"
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
import os
import matplotlib.pyplot as plt
import random
from PIL import Image
import numpy as np
import pandas as pd
# Labels the fine-tuned checkpoint loaded in classify() can predict, in the
# order of the network's output units (predict_single_image() uses this list
# as the index for the softmax output). It is tied to that training run and
# must be changed together with the checkpoint.
class_names = ['Annonaceae', 'Apocynaceae', 'Astrocarynum aculeatum', 'Attalea', 'Bactris', 'Bellucia', 'Bonomia', 'Byrsonima', 'Caryocar', 'Cecropia palmata', 'Cecropia ulei', 'Cordia', 'Couepia', 'Embaubas', 'Euterpe', 'Fabaceae', 'Genipa', 'Goupia', 'Guatteria', 'Humiria', 'Inga', 'Leopoldinia', 'Licania', 'Malvaceae', 'Meliaceae', 'Miconia', 'Myrcia', 'Ocotea', 'Oenocarpus', 'Pachira', 'Panopsis rubescens', 'Parkia', 'Passovia', 'Phoradendron', 'Pourouma', 'Remijia', 'Ruizterania', 'Salicaceae', 'Sapotaceae', 'Simarouba', 'Tachigali', 'Tapirira', 'Virola', 'Vochysia', 'Xylopia', 'cf. Swartzia']
# Derived from the label list so the classifier head size and the labels
# used to index its output can never drift apart.
num_classes = len(class_names)
# Backbone: torchvision's ResNet-50 built with the weights registered as
# ResNet50_Weights.DEFAULT (the pre-trained weights).
resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')
# Freeze the whole backbone: none of its parameters will require gradients.
resnet50.requires_grad_(False)
# Remember how wide the features entering the original classifier are (the
# custom head must accept that many inputs), then replace that classifier
# with a pass-through so the network emits feature vectors instead of logits.
num_ftrs = resnet50.fc.in_features
resnet50.fc = nn.Identity()
# NOTE(review): classify() loads its checkpoint with strict key matching into
# a CustomNet that wraps this very module, so these downloaded weights are
# overwritten on that path; consider weights=None if nothing else relies on
# the un-fine-tuned backbone.
class CustomNet(nn.Module):
    """Feature-extractor backbone followed by a one-hidden-layer classifier head.

    The backbone's features go through ``Linear(num_ftrs, 512) -> ReLU ->
    Linear(512, num_classes)``, producing unnormalised class scores (logits).

    Args:
        num_ftrs: Width of the feature vector the backbone produces
            (``resnet50.fc.in_features`` for the module-level backbone).
        num_classes: Number of output classes.
        backbone: Optional feature-extractor module. When omitted, the
            module-level ``resnet50`` object is used. That default is shared,
            not copied: every instance built without ``backbone`` wraps the
            same module, so loading weights into one instance or moving it to
            another device also affects the others' backbone.
    """

    def __init__(self, num_ftrs, num_classes, backbone=None):
        super().__init__()
        # The attribute must stay named ``resnet50`` so the state-dict keys
        # (``resnet50.*``) keep matching checkpoints saved from this class.
        self.resnet50 = resnet50 if backbone is None else backbone
        self.hidden = nn.Linear(num_ftrs, 512)
        self.relu = nn.ReLU()
        self.output = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits for the input batch ``x``.

        For a backbone returning ``(batch, num_ftrs)`` features, the result
        has shape ``(batch, num_classes)``.
        """
        features = self.resnet50(x)
        hidden = self.relu(self.hidden(features))
        return self.output(hidden)
# Deterministic inference-time preprocessing: resize straight to 224x224,
# convert to a float tensor and normalise with the ImageNet channel mean/std.
# The training-time `data_transforms` (which, per the original note, include
# random cropping) are deliberately not used, so predictions are repeatable.
# Built once at import instead of on every call.
_PREDICT_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def process_image(image_path):
    """Load an image file and turn it into a single-image model input batch.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the RGB image resized,
        normalised, and given a leading batch dimension.
    """
    # Use a context manager so the file handle is always released; Pillow
    # keeps multi-frame files (e.g. TIFF) open even after the pixels load.
    # convert() returns an independent image, so it stays valid after close.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return _PREDICT_TRANSFORM(rgb).unsqueeze(0)  # add batch dimension
def predict_single_image(image_path, model):
    """Classify one image and rank every known class by probability.

    The model passed in is switched to eval mode and moved (in place) to the
    first CUDA device when one is available, otherwise to the CPU. Inference
    runs without gradient tracking.

    Args:
        image_path: Image file to classify (see ``process_image``).
        model: A network whose output row has one logit per entry of the
            module-level ``class_names``.

    Returns:
        pandas Series mapping each class name to its softmax probability,
        sorted from most to least likely.
    """
    batch = process_image(image_path)

    model.eval()
    if torch.cuda.is_available():
        target = torch.device("cuda:0")
    else:
        target = torch.device("cpu")
    model = model.to(target)

    with torch.no_grad():
        logits = model(batch.to(target))
        # The batch holds a single image, so only row 0 is meaningful.
        scores = torch.nn.functional.softmax(logits[0], dim=0)

    ranked = pd.Series(scores.cpu().numpy(), index=class_names)
    return ranked.sort_values(ascending=False)
def classify(img_path, weights_path='./classification/fine_tuned_plant_classifier.pth'):
    """Build the classifier, load its fine-tuned weights and classify one image.

    Args:
        img_path: Image file to classify.
        weights_path: Checkpoint holding a ``CustomNet`` state dict. The
            default is relative to the current working directory, not to
            this file.

    Returns:
        pandas Series of per-class probabilities, sorted descending (see
        ``predict_single_image``).

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. when the
        checkpoint file is missing or its keys do not match the model.
    """
    model = CustomNet(num_ftrs, num_classes)
    # Deserialize onto the CPU so checkpoints saved from a GPU also load on
    # CPU-only machines; predict_single_image() then moves the model to the
    # best available device. Parameter values end up identical either way.
    # NOTE(review): torch.load unpickles arbitrary objects; pass
    # weights_only=True (torch >= 1.13) if the checkpoint may be untrusted.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    return predict_single_image(img_path, model)