Spaces:
Sleeping
Sleeping
File size: 2,594 Bytes
2201868 163e73a 2201868 0f694f1 fa3ae41 2201868 fa3ae41 163e73a fa3ae41 163e73a fa3ae41 163e73a fa3ae41 163e73a fa3ae41 52fd9c2 fa3ae41 b77b937 fa3ae41 2255b93 fa3ae41 5b86dff fa3ae41 163e73a fa3ae41 163e73a fa3ae41 163e73a fa3ae41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
import os
# Width of the classifier head. predict_from_file reports index 0 as
# "Fall armyworm" and any other index as "Healthy maize".
num_classes: int = 2
# Download model from Hugging Face
def download_model():
model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
return model_path
# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head from a checkpoint.

    The backbone starts without pretrained weights; ``load_state_dict``
    (strict by default) then fills every parameter from ``model_path``.

    Args:
        model_path: Path to a saved ``state_dict`` for this architecture.

    Returns:
        The network, mapped to CPU and switched to evaluation mode.
    """
    net = models.resnet50(pretrained=False)
    in_features = net.fc.in_features
    net.fc = nn.Linear(in_features, num_classes)
    # NOTE(review): torch.load unpickles the file. model_path comes from a
    # Hub download, so consider weights_only=True if the installed torch
    # supports it.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    net.load_state_dict(checkpoint)
    return net.eval()
# Fetch the checkpoint from the Hub, then build the classifier around it.
model_path = download_model()
model = load_model(model_path=model_path)
# Preprocessing pipeline:
#   1. resize the shorter side to 256 px;
#   2. take the central 224x224 crop;
#   3. convert to a float tensor;
#   4. normalize each channel with the standard ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# Function to load and preprocess image
def load_image_from_path(image_path):
    """Open an image file and turn it into a model-ready batch of one.

    Args:
        image_path: Filesystem path of the image to classify.

    Returns:
        The module-level ``transform`` applied to the image, with a leading
        batch dimension of size 1 added.

    Raises:
        FileNotFoundError: If nothing exists at ``image_path``.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found at {image_path}")
    # Close the file handle deterministically. convert() loads the pixel
    # data into a new image, so the result stays usable after the file closes.
    with Image.open(image_path) as img:
        # Normalize uses 3 per-channel statistics. Grayscale ("L"),
        # palette ("P") and RGBA images would otherwise yield tensors with
        # the wrong channel count.
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)  # add batch dimension
# Load the model (Example: ResNet50)
def load_model(model_path="model.pth"):
    """Build the ResNet-50 classifier from a local checkpoint file.

    This definition replaces the earlier ``load_model(model_path)`` for all
    later calls in the module. It accepts the same argument, and the default
    keeps the no-argument call working.

    Args:
        model_path: Path to a saved ``state_dict`` (default ``"model.pth"``,
            resolved against the current working directory).

    Returns:
        The network, mapped to CPU and switched to evaluation mode.
    """
    # pretrained=False: load_state_dict below (strict by default) overwrites
    # every parameter, so fetching ImageNet weights is a wasted download and
    # fails on hosts without network access.
    net = models.resnet50(pretrained=False)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()
    return net
# Predict from image tensor
def predict(image_tensor):
    """Return the index of the highest-scoring class for ``image_tensor``.

    Runs the module-level ``model`` with autograd disabled. The final
    ``.item()`` only succeeds when the batch contains exactly one image.
    That is presumably the batch-of-one tensor built by
    load_image_from_path; confirm for other callers.
    """
    with torch.no_grad():
        logits = model(image_tensor)
    best = logits.argmax(dim=1)
    return best.item()
# Initialize model
# The Hub checkpoint has already been loaded into `model` above. Only replace
# it when a local "model.pth" actually ships next to the app. An
# unconditional load_model() aborted startup with FileNotFoundError whenever
# that file was absent, and otherwise discarded the downloaded model.
if os.path.exists("model.pth"):
    model = load_model()
# Define the Gradio interface function
def predict_from_file(file_path):
    """Classify the image at ``file_path`` and package the answer as JSON.

    Returns:
        ``{"result": "Fall armyworm"}`` when the predicted index is 0,
        ``{"result": "Healthy maize"}`` for any other index, or
        ``{"error": <message>}`` if loading or inference raises.
    """
    try:
        class_index = predict(load_image_from_path(file_path))
        if class_index == 0:
            label = "Fall armyworm"
        else:
            label = "Healthy maize"
    except Exception as exc:
        # Deliberate catch-all at the UI/API boundary: report the failure
        # to the caller as JSON instead of failing the request.
        return {"error": str(exc)}
    return {"result": label}
# Gradio Interface
# NOTE(review): this endpoint opens arbitrary paths on the *server's*
# filesystem on behalf of remote callers. The "not found" message differs
# from image-decoding errors, so callers can also probe which paths exist.
# Consider an upload component if clients do not share the server's disk.
path_input = gr.Textbox(label="Image Path (Local)")
json_output = gr.JSON()
iface = gr.Interface(
    fn=predict_from_file,
    inputs=path_input,
    outputs=json_output,
    title="Maize Anomaly Detection",
    description="Send a local file path via POST request to trigger prediction.",
    # live=True re-runs the function automatically whenever the input changes.
    live=True,
)
# Launch the Gradio app
# Guarded so that importing this module (e.g. from tests or Gradio's reload
# mode) does not start a blocking web server as a side effect.
if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    # NOTE(review): share=True asks Gradio for a public share link when run
    # locally. Together with the server-path input, that lets anyone with
    # the link make this machine open files by path.
    iface.launch(share=True, server_name="0.0.0.0", server_port=7860)
|