|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torchvision import datasets, models, transforms |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data import DataLoader, random_split |
|
import os |
|
import matplotlib.pyplot as plt |
|
import random |
|
from PIL import Image |
|
import numpy as np |
|
import pandas as pd |
|
|
|
|
|
# Labels the classifier can report. Position i in this list names output i of
# the model: predict_single_image uses the list as the index of the
# probability vector, so the order must not change.
class_names = [
    'Annonaceae',
    'Apocynaceae',
    'Astrocarynum aculeatum',
    'Attalea',
    'Bactris',
    'Bellucia',
    'Bonomia',
    'Byrsonima',
    'Caryocar',
    'Cecropia palmata',
    'Cecropia ulei',
    'Cordia',
    'Couepia',
    'Embaubas',
    'Euterpe',
    'Fabaceae',
    'Genipa',
    'Goupia',
    'Guatteria',
    'Humiria',
    'Inga',
    'Leopoldinia',
    'Licania',
    'Malvaceae',
    'Meliaceae',
    'Miconia',
    'Myrcia',
    'Ocotea',
    'Oenocarpus',
    'Pachira',
    'Panopsis rubescens',
    'Parkia',
    'Passovia',
    'Phoradendron',
    'Pourouma',
    'Remijia',
    'Ruizterania',
    'Salicaceae',
    'Sapotaceae',
    'Simarouba',
    'Tachigali',
    'Tapirira',
    'Virola',
    'Vochysia',
    'Xylopia',
    'cf. Swartzia',
]

# Derived from the label list (46 entries) so the two can never drift apart.
num_classes = len(class_names)
|
|
|
|
|
# Backbone shared by every CustomNet instance, created with the default
# pretrained ResNet-50 weights.
resnet50 = models.resnet50(weights='ResNet50_Weights.DEFAULT')

# Freeze the backbone: none of its parameters should receive gradients.
for param in resnet50.parameters():
    param.requires_grad_(False)

# Record the input width of the original final layer *before* replacing it,
# then swap that layer for a pass-through so the backbone emits the feature
# vector instead of the original layer's output.
num_ftrs = resnet50.fc.in_features
resnet50.fc = nn.Identity()
|
|
|
|
|
class CustomNet(nn.Module):
    """Feature-extractor backbone followed by a two-layer classification head.

    The head is ``Linear(num_ftrs, 512) -> ReLU -> Linear(512, num_classes)``.
    The submodule attribute names (``resnet50``, ``hidden``, ``relu``,
    ``output``) determine the ``state_dict`` keys, so they are kept stable to
    stay compatible with checkpoints saved from this class.

    Args:
        num_ftrs: Size of the feature vector the backbone produces.
        num_classes: Number of output scores.
        backbone: Module mapping an input batch to ``num_ftrs`` features.
            When omitted, the module-level ``resnet50`` is used, which is the
            behaviour callers had before this parameter existed.
    """

    def __init__(
        self,
        num_ftrs: int,
        num_classes: int,
        *,
        backbone: "nn.Module | None" = None,
    ):
        super().__init__()
        # The module-level backbone is only looked up when no explicit one is
        # given, so the class can be built without that global (e.g. in tests).
        self.resnet50 = resnet50 if backbone is None else backbone
        self.hidden = nn.Linear(num_ftrs, 512)
        self.relu = nn.ReLU()
        self.output = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores (no softmax applied) for batch ``x``."""
        features = self.resnet50(x)
        return self.output(self.relu(self.hidden(features)))
|
|
|
|
|
|
|
def process_image(image_path):
    """Load an image file and turn it into a single-image batch tensor.

    The image is converted to RGB, resized to 224x224, converted to a tensor
    and normalised per channel with the mean/std constants below.

    Args:
        image_path: Location of the image, passed to ``PIL.Image.open``.

    Returns:
        The transformed image with an extra leading dimension added by
        ``unsqueeze(0)``, so it can be fed to a model as a batch of one.
    """
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Open inside a context manager so the underlying file is closed
    # deterministically (multi-frame formats otherwise keep it open until the
    # image object is garbage-collected). convert() returns an independent
    # in-memory image, so it remains usable after the file is closed.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    return data_transform(image).unsqueeze(0)
|
|
|
|
|
|
|
def predict_single_image(image_path, model):
    """Run ``model`` on one image and rank every class by probability.

    Args:
        image_path: Image to score; loaded through ``process_image``.
        model: Network to evaluate. It is put into eval mode and moved to the
            first CUDA device when one is available, otherwise to the CPU.

    Returns:
        A ``pandas.Series`` of softmax probabilities indexed by
        ``class_names`` and sorted from most to least probable.
    """
    batch = process_image(image_path)

    model.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(batch.to(device))
        # The batch holds a single image, so score only its first row.
        probabilities = torch.softmax(logits[0], dim=0)

    scores = pd.Series(probabilities.cpu().numpy(), index=class_names)
    return scores.sort_values(ascending=False)
|
|
|
def classify(img_path):
    """Classify one image with the fine-tuned plant classifier.

    A fresh ``CustomNet`` is built and its weights are read from
    ``./classification/fine_tuned_plant_classifier.pth`` (a path relative to
    the current working directory) on every call.

    Args:
        img_path: Image to classify; forwarded to ``predict_single_image``.

    Returns:
        The result of ``predict_single_image`` for this image and model.
    """
    model = CustomNet(num_ftrs, num_classes)

    # Deserialize onto the CPU so a checkpoint saved from a GPU still loads on
    # a CPU-only machine; predict_single_image moves the model to whichever
    # device it selects afterwards.
    state_dict = torch.load(
        './classification/fine_tuned_plant_classifier.pth',
        map_location='cpu',
    )
    model.load_state_dict(state_dict)

    return predict_single_image(img_path, model)