"""Gradio app that classifies a maize image as fall-armyworm damage or healthy."""

import json  # NOTE(review): not used in this file as shown; kept in case something relies on it
import logging
import os

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torchvision import models, transforms

logger = logging.getLogger(__name__)

# Human-readable label for each output index of the classifier head.
# Order must match the label indices used when model.pth was trained
# (index 0 -> "Fall armyworm", index 1 -> "Healthy maize").
CLASS_NAMES = ("Fall armyworm", "Healthy maize")

# Define the number of classes (derived so labels and head size cannot drift apart)
num_classes = len(CLASS_NAMES)

# Preprocessing: resize, center-crop to 224x224, convert to a tensor and
# normalize with per-channel mean/std. The 3-element mean/std requires
# 3-channel input, which load_image_from_path guarantees.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def load_image_from_path(image_path: str) -> torch.Tensor:
    """Load an image file and return it as a preprocessed batch of one.

    Args:
        image_path: Path to an image file on the local filesystem.

    Returns:
        A tensor produced by ``transform`` with a leading batch dimension of 1.

    Raises:
        FileNotFoundError: If ``image_path`` is empty or is not an existing file.
        PIL.UnidentifiedImageError: If the file is not a readable image.
    """
    # isfile (not exists) so directories are rejected with the same clear
    # message; the empty check also covers None / "" from the textbox.
    if not image_path or not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found at {image_path}")
    with Image.open(image_path) as image:
        # Force 3 channels: grayscale, palette or RGBA images would otherwise
        # yield a tensor the 3-channel Normalize rejects. convert() returns a
        # loaded copy, so it stays valid after the file is closed.
        rgb_image = image.convert("RGB")
    return transform(rgb_image).unsqueeze(0)  # Convert to tensor and add batch dimension


def load_model(weights_path: str = "model.pth") -> nn.Module:
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        weights_path: Path to a state_dict saved for this architecture.

    Returns:
        The model in eval mode, on CPU.
    """
    # No ImageNet download: the strict load_state_dict below overwrites every
    # parameter, so pretrained weights would be fetched only to be discarded.
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict(image_tensor: torch.Tensor) -> int:
    """Return the index of the highest-scoring class for a single-image batch.

    Uses the module-level ``model``. Expects a batch of exactly one image, as
    produced by ``load_image_from_path``; ``.item()`` raises for larger batches.
    """
    with torch.no_grad():
        outputs = model(image_tensor)
    return torch.argmax(outputs, dim=1).item()


# Initialize model
model = load_model()


def predict_from_file(file_path: str) -> dict:
    """Gradio handler: classify the image at ``file_path``.

    Returns:
        ``{"result": <label>}`` on success, or ``{"error": <message>}`` if
        loading or inference fails. Errors are returned rather than raised so
        the JSON output always has a body.
    """
    try:
        image_tensor = load_image_from_path(file_path)
        predicted_class = predict(image_tensor)
    except Exception as e:  # UI boundary: report to the caller, but keep a traceback in the logs
        logger.exception("Prediction failed for %r", file_path)
        return {"error": str(e)}
    return {"result": CLASS_NAMES[predicted_class]}


# Gradio Interface
iface = gr.Interface(
    fn=predict_from_file,
    inputs=gr.Textbox(label="Image Path (Local)"),
    outputs=gr.JSON(),
    # live=True re-ran inference on every keystroke, so partially typed paths
    # produced repeated "not found" errors. Predictions now run on submit.
    live=False,
    title="Maize Anomaly Detection",
    description="Send a local file path via POST request to trigger prediction.",
)

if __name__ == "__main__":
    # SECURITY NOTE(review): share=True publishes a public URL, and the input is
    # a path on *this* server. Anyone with the link can probe which files exist
    # here and try to open them as images. Confirm this exposure is intended.
    iface.launch(share=True, server_name="0.0.0.0", server_port=7860)