File size: 3,741 Bytes
9d881b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d70836
9d881b0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
import os
import matplotlib.pyplot as plt
import random
from PIL import Image
import numpy as np
import pandas as pd

# Label set of the trained checkpoint that classify() loads. The model's output
# index i corresponds to class_names[i] (predict_single_image uses this list as
# the index of its result), so the order must match the training run exactly.
# Update this list, in the same order, whenever the model is retrained on a
# different label set.
class_names = [
    'Annonaceae', 'Apocynaceae', 'Astrocarynum aculeatum', 'Attalea', 'Bactris',
    'Bellucia', 'Bonomia', 'Byrsonima', 'Caryocar', 'Cecropia palmata',
    'Cecropia ulei', 'Cordia', 'Couepia', 'Embaubas', 'Euterpe', 'Fabaceae',
    'Genipa', 'Goupia', 'Guatteria', 'Humiria', 'Inga', 'Leopoldinia', 'Licania',
    'Malvaceae', 'Meliaceae', 'Miconia', 'Myrcia', 'Ocotea', 'Oenocarpus',
    'Pachira', 'Panopsis rubescens', 'Parkia', 'Passovia', 'Phoradendron',
    'Pourouma', 'Remijia', 'Ruizterania', 'Salicaceae', 'Sapotaceae', 'Simarouba',
    'Tachigali', 'Tapirira', 'Virola', 'Vochysia', 'Xylopia', 'cf. Swartzia',
]
# Derived from the label list so the head size can never drift out of sync
# with the labels (it is 46 for the list above).
num_classes = len(class_names)

# Feature-extraction backbone: torchvision's ResNet-50 with its DEFAULT
# pre-trained weights and the classification head removed, so calling it
# returns the feature vector that CustomNet's hidden layer consumes.
# NOTE(review): CustomNet registers this module as a submodule and classify()
# then calls load_state_dict on the whole network, so these downloaded weights
# appear to be overwritten at inference time — confirm whether the download
# is actually needed before changing it.
resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')

# Freeze the backbone: Module.requires_grad_(False) clears requires_grad on
# every parameter in one call.
resnet50.requires_grad_(False)

# Record the width of the features that fed the original fc layer, then swap
# fc for an identity so forward() yields those features unchanged.
num_ftrs = resnet50.fc.in_features
resnet50.fc = nn.Identity()

# Classifier = feature backbone -> one hidden layer (512 units, ReLU) -> logits.
class CustomNet(nn.Module):
    """Image classifier: backbone features -> Linear(512) -> ReLU -> Linear.

    Args:
        num_ftrs: Width of the feature vector produced by the backbone.
        num_classes: Number of output logits.
        backbone: Optional feature extractor. When omitted, the module-level
            ``resnet50`` (with its ``fc`` replaced by ``nn.Identity``) is used,
            as before. It is stored as ``self.resnet50`` either way so that
            state_dict keys stay ``resnet50.*`` and existing checkpoints load.
    """

    def __init__(self, num_ftrs, num_classes, backbone=None):
        super().__init__()
        # Fall back to the shared module-level backbone. Every instance built
        # without an explicit backbone shares that single module object.
        self.resnet50 = resnet50 if backbone is None else backbone
        self.hidden = nn.Linear(num_ftrs, 512)
        self.relu = nn.ReLU()
        self.output = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) logits for the input batch ``x``."""
        x = self.resnet50(x)  # Extract features with the backbone
        x = self.hidden(x)  # Project features into the hidden layer
        x = self.relu(x)  # Non-linearity
        return self.output(x)  # One logit per class


# Load an image from disk and turn it into a model-ready batch tensor.
def process_image(image_path):
    """Load an image and convert it into a normalized single-image batch.

    The preprocessing is deterministic, with no random cropping or flipping:
    resize to 224x224, convert to a float tensor, and normalize each channel
    with the mean/std values below (the usual ImageNet statistics).

    Args:
        image_path: A path (or file object) that ``PIL.Image.open`` accepts.

    Returns:
        A tensor of shape (1, 3, 224, 224).
    """
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Close the file handle deterministically instead of leaving it open until
    # the Image object is garbage-collected; convert() returns an independent,
    # fully loaded copy, so it stays valid after the source is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = data_transform(image)
    return image.unsqueeze(0)  # Add batch dimension


# Run one image through a model and rank all known classes.
def predict_single_image(image_path, model):
    """Classify a single image and rank every class by probability.

    Args:
        image_path: Image to classify; preprocessed with ``process_image``.
        model: Network whose output for a one-image batch has one logit per
            entry of the module-level ``class_names``.

    Returns:
        pandas.Series mapping class name -> softmax probability, sorted from
        most to least likely.
    """
    batch = process_image(image_path)

    # Put the model in inference mode and move it to the GPU if one exists.
    model.eval()
    target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(target)

    with torch.no_grad():
        logits = model(batch.to(target))
        # Only one image is in the batch, so take row 0 and softmax over classes.
        probs = torch.nn.functional.softmax(logits[0], dim=0)

    ranked = pd.Series(probs.cpu().numpy(), index=class_names)
    return ranked.sort_values(ascending=False)

def classify(img_path, weights_path='./classification/fine_tuned_plant_classifier.pth'):
    """Load the fine-tuned classifier and return class probabilities for one image.

    Args:
        img_path: Path to the image to classify.
        weights_path: state_dict checkpoint for ``CustomNet``. Defaults to the
            original hard-coded location, resolved relative to the current
            working directory.

    Returns:
        pandas.Series of class name -> probability, most likely first.
    """
    model = CustomNet(num_ftrs, num_classes)

    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; predict_single_image moves the model to the proper
    # device afterwards.
    # NOTE: torch.load unpickles the file, so only point weights_path at
    # trusted checkpoints.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)

    return predict_single_image(img_path, model)