import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import models, transforms
from ultralytics import YOLO
import gradio as gr
import torch.nn as nn
# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load models
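# best.pt  -> YOLO detector that locates individual rice grains in the image
# rice_resnet_model.pth -> ResNet-50 with a 3-class head that labels each detected grain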
yolo_model = YOLO('best.pt') # Make sure this file is uploaded
resnet = models.resnet50(pretrained=False)
resnet.fc = nn.Linear(resnet.fc.in_features, 3)
resnet.load_state_dict(torch.load('rice_resnet_model.pth', map_location=device))
resnet = resnet.to(device)
resnet.eval()  # inference mode (BatchNorm uses running statistics)
# Class labels
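# NOTE: the order must match the class-index order used when the ResNet head was trained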
class_labels = ["c9", "kant", "superf"]
# Image transformations
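# Standard ResNet preprocessing: 224x224 resize plus ImageNet mean/std normalization
# (assumed to match the transforms used when rice_resnet_model.pth was trained)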
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def classify_crop(crop_img):
    """Classify a single rice grain"""
    image = transform(crop_img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = resnet(image)
        _, predicted = torch.max(output, 1)
    return class_labels[predicted.item()]

def detect_and_classify(input_image):
    """Detect rice grains in the uploaded image and classify each detection"""
    # Convert the Gradio PIL image to OpenCV BGR format
    image = np.array(input_image)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # YOLO detection
    results = yolo_model(image)[0]
    boxes = results.boxes.xyxy.cpu().numpy()

    # Classify and annotate each detected grain
    for box in boxes:
        x1, y1, x2, y2 = map(int, box[:4])
        crop = image[y1:y2, x1:x2]
        crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
        predicted_label = classify_crop(crop_pil)

        # Draw bounding box and label
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(image,
                    predicted_label,
                    (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.9,
                    (36, 255, 12),
                    2)

    # Convert back to RGB for Gradio
    return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

# Create Gradio interface
with gr.Blocks(title="Rice Classification") as demo:
    gr.Markdown("""
## 🍚 Rice Variety Classifier
Upload an image containing rice grains. The system will detect and classify each grain.
""")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Rice Image")
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Detection Results", interactive=False)
    submit_btn.click(
        fn=detect_and_classify,
        inputs=image_input,
        outputs=output_image
    )

# Launch the app
demo.launch()
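# Note: plain launch() is enough when hosted on Hugging Face Spaces; when running
# locally, demo.launch(share=True) can be used to get a temporary public link.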