Spaces:
Running
Running
File size: 2,767 Bytes
81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd 8445179 81f02dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import models, transforms
from ultralytics import YOLO
import gradio as gr
import torch.nn as nn
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Variety names for the classifier head; classify_crop maps the argmax
# logit index straight into this list, so its order must match training.
class_labels = ["c9", "kant", "superf"]

# Grain detector. The 'best.pt' weights file must be uploaded next to this script.
yolo_model = YOLO('best.pt')

# ResNet-50 variety classifier. `weights=None` builds an uninitialised
# backbone (the current spelling of the deprecated `pretrained=False`);
# the fine-tuned weights are loaded immediately afterwards.
resnet = models.resnet50(weights=None)
# Size the head from the label list so the two can never drift apart.
resnet.fc = nn.Linear(resnet.fc.in_features, len(class_labels))
# The checkpoint is a plain state dict, so refuse to unpickle anything else.
resnet.load_state_dict(
    torch.load('rice_resnet_model.pth', map_location=device, weights_only=True)
)
resnet = resnet.to(device)
resnet.eval()  # inference mode: freezes batch-norm statistics, disables dropout

# Preprocessing for classifier crops: fixed 224x224 input, then ImageNet
# mean/std normalisation per RGB channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def classify_crop(crop_img):
    """Classify a single cropped rice grain.

    Args:
        crop_img: The grain crop, normally a PIL image. A PIL image that is
            not already RGB (for example RGBA or grayscale) is converted to
            RGB first, because ``transform`` normalises exactly three channels
            and the ResNet expects three-channel input.

    Returns:
        The predicted variety name, taken from ``class_labels``.
    """
    if isinstance(crop_img, Image.Image) and crop_img.mode != "RGB":
        crop_img = crop_img.convert("RGB")
    # Preprocess, add a batch dimension of 1, and move to the model's device.
    batch = transform(crop_img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = resnet(batch)
    # Index of the highest-scoring class for the single item in the batch.
    index = int(logits.argmax(dim=1).item())
    return class_labels[index]
def detect_and_classify(input_image):
    """Detect rice grains, classify each one, and draw the results.

    Args:
        input_image: The image from the ``gr.Image(type="pil")`` input (a PIL
            image), or ``None`` when "Analyze" is clicked without an upload.

    Returns:
        A PIL RGB image with a green box and predicted variety label for every
        detection, or ``None`` when no image was provided.
    """
    if input_image is None:
        # Nothing uploaded: leave the output empty instead of crashing in OpenCV.
        return None

    if isinstance(input_image, Image.Image):
        # Force 3 channels so RGBA / grayscale uploads convert cleanly below.
        input_image = input_image.convert("RGB")
    rgb = np.array(input_image)
    # OpenCV drawing and YOLO's numpy input path both work in BGR order.
    image = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    height, width = image.shape[:2]

    # YOLO detection: one (x1, y1, x2, y2) pixel box per detected grain.
    results = yolo_model(image)[0]
    boxes = results.boxes.xyxy.cpu().numpy()

    for box in boxes:
        x1, y1, x2, y2 = map(int, box[:4])
        # Clamp to the image so slicing never wraps on negative indices.
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            # Sub-pixel / degenerate box: the crop would be empty.
            continue

        # Crop from the untouched RGB array, not from `image`: boxes and labels
        # already drawn for neighbouring grains must not leak into this crop.
        crop_pil = Image.fromarray(rgb[y1:y2, x1:x2])
        predicted_label = classify_crop(crop_pil)

        # Draw bounding box and label; keep the label on-canvas for grains
        # touching the top edge of the image.
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            image,
            predicted_label,
            (x1, max(y1 - 10, 20)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.9,
            (36, 255, 12),
            2,
        )

    # Convert back to RGB for Gradio.
    return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# Gradio UI: the upload column with an "Analyze" button on the left and the
# annotated detection result on the right.
with gr.Blocks(title="Rice Classification") as demo:
    gr.Markdown("""
## 🍚 Rice Variety Classifier
Upload an image containing rice grains. The system will detect and classify each grain.
""")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Rice Image")
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Detection Results", interactive=False)

    # Wire the button: uploaded image in, annotated image out.
    submit_btn.click(
        fn=detect_and_classify,
        inputs=image_input,
        outputs=output_image,
    )

# Start the web server.
demo.launch()