# app.py — Hugging Face Space file by izeeek
# Last commit: "Update app.py" (e7f56c5, verified), 2.17 kB
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import requests
from torchvision.models import vgg19
import gradio as gr
# Image preprocessing pipeline applied to every input before inference:
# resize to a fixed 224x224, convert to a float tensor, then normalize each
# channel with the ImageNet mean/std statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Build the VGG19 architecture that the fine-tuned checkpoint was trained on.
# weights=None: load_state_dict below is strict by default, so every
# parameter is overwritten by the checkpoint anyway. Fetching the ImageNet
# weights first would only cost a large download and slower startup.
model = models.vgg19(weights=None)

# Replace the final fully connected layer with a 2-output head for binary
# classification. Its input width must match the layer feeding into it.
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_ftrs, 2)

# Load the fine-tuned weights (the file must exist next to the working dir).
# map_location="cpu": a checkpoint saved from a GPU still loads on a CPU-only
# host. The model is never moved to another device in this script.
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load(
        "rice_plant_classification.pth",
        map_location="cpu",
        weights_only=True,
    )
)
# Inference mode: disables the dropout layers in the VGG classifier.
model.eval()

# Map the model's output indices to human-readable labels.
class_to_label = {0: "Healthy", 1: "Unhealthy"}
# Inference function
def predict(image):
    """Classify a rice-plant image as healthy or unhealthy.

    Args:
        image: The uploaded image, either a NumPy array (what Gradio's image
            input delivers with type="numpy") or a ``PIL.Image.Image``.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the predicted class
        from ``class_to_label``. ``confidence`` is that class's softmax
        probability, formatted as a percentage with two decimals
        (e.g. ``"97.31%"``).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio hands None to the function when nothing was uploaded. Show a
        # readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please upload an image.")

    img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # Grayscale or RGBA images would not match the 3-channel Normalize step.
    img = img.convert("RGB")
    batch = preprocess(img).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output = model(batch)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0, predicted_class].item()

    return class_to_label[predicted_class], f"{confidence * 100:.2f}%"
# Example images offered below the input (relative paths; presumably shipped
# alongside this script — confirm they exist in the deployment).
example_images = ["healthy.jpg", "unhealthy.jpg"]

# Build the Gradio UI. The image input is declared explicitly so uploads
# reach predict() as RGB NumPy arrays, the form it passes to
# Image.fromarray.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", image_mode="RGB"),
    outputs=[gr.Textbox(label="Prediction"), gr.Textbox(label="Confidence")],
    title="Healthy vs Unhealthy Rice Plant Classifier",
    description="Upload a rice plant image to classify whether it is healthy or unhealthy.",
    examples=example_images,
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()