"""Classify a sugarcane leaf image with a fine-tuned ViT model.

Importing/running this module loads the ViT checkpoint onto the CPU, builds
the preprocessing pipeline, and runs one sample prediction whose result is
printed to stdout.
"""

import contextlib
import os
import sys
import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification

# Suppress warnings related to the model weights initialization,
# FutureWarning and UserWarnings.
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")


@contextlib.contextmanager
def suppress_stdout():
    """Temporarily send everything written to ``sys.stdout`` to the null device.

    Only Python-level writes through ``sys.stdout`` are silenced; the previous
    stream is restored on exit, even if the wrapped code raises.
    """
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
        yield


# Load the saved model while hiding verbose stdout chatter (file copying,
# model initialization messages).
with suppress_stdout():
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224-in21k',
        num_labels=6,
    )
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider `weights_only=True` on torch versions that
    # support it).
    model.load_state_dict(
        torch.load(
            'vit_sugarcane_disease_detection.pth',
            map_location=torch.device('cpu'),
        )
    )
    model.eval()

# Preprocessing applied to every input image. Per the original author's note
# this must be the same transformation that was used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class names (disease types); the position in this list is the label index
# produced by the model (6 entries, matching num_labels=6 above).
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']


def predict_disease(image_path) -> str:
    """Predict the disease type shown in an image file.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The entry of ``class_names`` with the highest logit.

    Raises:
        Whatever ``PIL.Image.open`` raises when the file is missing or is not
        a readable image.
    """
    # Close the file handle deterministically, and force three channels:
    # PNGs are often RGBA / palette / grayscale, which would not match the
    # 3-channel Normalize step (or the model input).
    with Image.open(image_path) as img:
        rgb_img = img.convert("RGB")

    # Apply transformations and add the batch dimension.
    img_tensor = transform(rgb_img).unsqueeze(0)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(img_tensor)

    predicted_index = outputs.logits.argmax(dim=1).item()
    return class_names[predicted_index]


# Test with a new image
image_path = 'zoomed_Bacterial%20Blight_614.png'  # Replace with your image path
predicted_label = predict_disease(image_path)
print(f'The predicted disease type is: \n{predicted_label}')