File size: 1,490 Bytes
68aa7ad
a9f69a2
68aa7ad
 
a9f69a2
68aa7ad
a9f69a2
68aa7ad
a9f69a2
 
68aa7ad
a9f69a2
 
 
 
 
 
 
68aa7ad
 
 
a9f69a2
 
 
 
 
 
 
 
 
68aa7ad
 
 
 
 
 
a9f69a2
 
 
68aa7ad
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
from transformers import AutoModelForImageClassification, AutoConfig
from torchvision import transforms
from PIL import Image
import torch

# Model repository id passed to the transformers `from_pretrained` loaders.
MODEL_NAME = "dwililiya/sugarcane-plant-diseases-classification"

# Load the checkpoint's configuration first, then build the image classifier from it.
config = AutoConfig.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    config=config,
)

# Preprocessing pipeline: resize every input to 256x256, then convert the PIL
# image to a tensor (torchvision's ToTensor produces a C x H x W float tensor).
# NOTE(review): no mean/std normalization is applied here — confirm this matches
# the preprocessing the checkpoint was trained with.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# Human-readable labels, looked up by the index of the model's highest logit.
# NOTE(review): this ordering is assumed to match the checkpoint's label ids
# (config.id2label) — verify against the model card.
class_names = [
    "Bacterial Blight",
    "Healthy",
    "Mosaic",
    "Red Rot",
    "Rust",
    "Yellow",
]

def predict(image):
    """Classify a sugarcane leaf image with the module-level model.

    Args:
        image: A PIL image, as delivered by the ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        A ``(predicted_class, confidence)`` tuple. ``predicted_class`` is the
        entry of ``class_names`` at the index of the highest logit.
        ``confidence`` is that class's softmax probability as a Python float.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message in the UI
            instead of a stack trace.
    """
    if image is None:
        # Without this guard, transform(None) fails with an opaque TypeError.
        raise gr.Error("Please upload an image of a sugarcane leaf.")

    # Force 3 channels. RGBA, grayscale, or palette images would otherwise
    # give ToTensor a tensor with the wrong channel count for the model.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add batch dimension

    # Perform inference without tracking gradients.
    with torch.no_grad():
        logits = model(batch).logits
        _, predicted = torch.max(logits, dim=1)
        index = predicted.item()
        confidence = torch.softmax(logits, dim=1)[0, index].item()

    return class_names[index], confidence

# Gradio UI: one PIL image in; the predicted label and its confidence score out.
# fn, inputs and outputs are passed positionally (gr.Interface's first three parameters).
iface = gr.Interface(
    predict,
    gr.Image(type="pil", label="Upload Sugarcane Leaf Image"),
    [
        gr.Label(num_top_classes=1, label="Predicted Class"),
        gr.Textbox(label="Confidence Score"),
    ],
    title="Sugarcane Plant Diseases Classification",
    description="Upload an image of a sugarcane leaf to classify its disease.",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    iface.launch()