# Hugging Face Spaces status: Runtime error
import gradio as gr | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
import warnings | |
import sys | |
import os | |
import contextlib | |
from transformers import ViTForImageClassification | |
# Silence noisy third-party warnings raised through the `warnings` module:
# UserWarnings attributed to `transformers` modules and FutureWarnings
# attributed to `torch` modules (e.g. deprecation notices from torch.load).
# NOTE(review): messages that libraries print through `logging` rather than
# `warnings` (possibly including the uninitialized-weights notice from
# from_pretrained) are not affected by these filters — confirm if needed.
for _category, _module in ((UserWarning, "transformers"), (FutureWarning, "torch")):
    warnings.filterwarnings("ignore", category=_category, module=_module)
del _category, _module
# Suppress output for copying files and verbose model initialization messages | |
def suppress_stdout(): | |
with open(os.devnull, 'w') as devnull: | |
old_stdout = sys.stdout | |
sys.stdout = devnull | |
try: | |
yield | |
finally: | |
sys.stdout = old_stdout | |
# Load the saved model and suppress the warnings | |
with suppress_stdout(): | |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6) | |
model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu'))) | |
model.eval() | |
# Define the same transformation used during training | |
transform = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
]) | |
# Load the class names (disease types) | |
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow'] | |
# Function to predict disease type from an image | |
def predict_disease(image_path): | |
# Open the image file | |
img = Image.open(image_path) | |
# Apply transformations to the image | |
img_tensor = transform(img).unsqueeze(0) # Add batch dimension | |
# Make prediction | |
with torch.no_grad(): | |
outputs = model(img_tensor) | |
_, predicted_class = torch.max(outputs.logits, 1) | |
# Get the predicted label | |
predicted_label = class_names[predicted_class.item()] | |
return predicted_label | |
# Test with a new image | |
image_path = 'img1.png' # Replace with your image path | |
predicted_label = predict_disease(image_path) | |
print(f'The predicted disease type is: {predicted_label}') | |