Spaces:
Sleeping
Sleeping
File size: 3,298 Bytes
cbc5566 9dfc63c 2338cca cbc5566 38d7439 cbc5566 38d7439 9dfc63c eff8876 9dfc63c 38d7439 9dfc63c 38d7439 9dfc63c 38d7439 9dfc63c 38d7439 9dfc63c 38d7439 9dfc63c 38d7439 9dfc63c 38d7439 2b3983d 38d7439 9dfc63c 2b3983d 38d7439 9dfc63c fb31436 a62d15d 38d7439 a62d15d 38d7439 2338cca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
from flask import Flask, request, jsonify
# Number of output logits produced by the classifier head built in load_model().
# predict() maps index 0 to "fall army worm" and index 1 to "healthy wheat", and
# load_state_dict() will reject a checkpoint whose head has a different size,
# so keep this value in sync with both.
num_classes: int = 2
def download_model():
    """Fetch the fine-tuned ResNet-50 checkpoint from the Hugging Face Hub.

    Requests ``pytorch_model.bin`` from the ``jays009/Restnet50`` repository
    through ``hf_hub_download``.

    Returns:
        The local file path reported by ``hf_hub_download``.
    """
    repo, weights_file = "jays009/Restnet50", "pytorch_model.bin"
    return hf_hub_download(repo_id=repo, filename=weights_file)
def load_model(model_path, device=None):
    """Build a ResNet-50 classifier and load fine-tuned weights into it.

    The final fully-connected layer is replaced with one that outputs
    ``num_classes`` logits so its shape matches the checkpoint's head.

    Args:
        model_path: Path to a ``state_dict`` checkpoint (for example the path
            returned by ``download_model``).
        device: Optional ``torch.device`` or device string to place the model
            on. Defaults to CUDA when available, otherwise CPU, so inference
            runs on the GPU when one is present.

    Returns:
        The model, in evaluation mode, located on ``device``.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match this architecture (e.g. a different
            ``num_classes``).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # No pretrained download needed: load_state_dict (strict by default)
    # overwrites every parameter and buffer from the checkpoint below.
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # Deserialize onto the CPU first so a checkpoint saved from a GPU still
    # loads on a CPU-only host; the model is moved to `device` afterwards.
    # NOTE(review): torch.load unpickles the file and this checkpoint comes from
    # a remote hub repository; consider weights_only=True if the installed
    # torch version supports it.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Keep the model on the same device callers send their inputs to; leaving it
    # on the CPU while inputs go to CUDA raises a device-mismatch error.
    model.to(device)
    model.eval()  # disable dropout / use running batch-norm statistics
    return model
# Import-time setup: fetch the checkpoint once and build the shared model that
# predict() runs inference with.
model_path = download_model()
model = load_model(model_path=model_path)
# Preprocessing applied to every input image before inference:
#   1. Resize(256)     - scale so the *shorter* side is 256 px (aspect kept).
#   2. CenterCrop(224) - keep the central 224x224 region.
#   3. ToTensor()      - PIL image -> (C, H, W) float tensor; 8-bit images are
#                        scaled to [0, 1].
#   4. Normalize(...)  - per-channel standardisation with ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image):
    """Classify a wheat photo and return a human-readable verdict.

    Args:
        image: A ``PIL.Image.Image`` (as produced by ``gr.Image(type="pil")``
            or opened by the Flask endpoint), or ``None`` when no image is
            present.

    Returns:
        A message string describing the predicted class.
    """
    if image is None:
        # Gradio can call a live interface with None (e.g. after the input is
        # cleared); answer politely instead of crashing inside the transform.
        return "Please upload an image."
    # Normalize uses 3-channel statistics and ResNet-50 expects RGB input, so
    # palette, greyscale and RGBA uploads must be converted first.
    rgb_image = image.convert("RGB")
    # Send the input to whichever device the model actually lives on; choosing
    # CUDA independently of the model causes a device-mismatch RuntimeError.
    device = next(model.parameters()).device
    batch = transform(rgb_image).unsqueeze(0).to(device)  # add batch dim -> (1, C, H, W)
    with torch.no_grad():
        logits = model(batch)
    predicted_class = torch.argmax(logits, dim=1).item()
    if predicted_class == 0:
        return "The photo you've sent is of fall army worm with problem ID 126."
    if predicted_class == 1:
        return "The photo you've sent is of a healthy wheat image."
    return "Unexpected class prediction."
# Gradio UI: a PIL-image upload feeds predict(), and the returned verdict is
# shown in a text box. live=True re-runs predict whenever the input changes.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    live=True,
    title="Wheat Anomaly Detection",
    description="Upload an image of wheat to detect anomalies like disease or pest infestation.",
)

# Start the UI; share=True asks Gradio for a public share link.
# NOTE(review): in a plain script launch() is expected to block until the
# server stops, so the Flask app defined below would only be created and run
# afterwards - confirm whether both servers are meant to run together.
iface.launch(share=True)
# Flask app exposing the classifier as a JSON HTTP API.
app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def api_predict():
    """Classify an image stored at a server-side path named in the request.

    Expects a JSON body of the form ``{"inputs": "<path to image file>"}``.

    Returns:
        ``{"result": <message>}`` with status 200 on success, or
        ``{"error": <message>}`` with status 400 when the body is missing or
        invalid, or the image cannot be opened or classified.
    """
    # silent=True returns None (instead of raising) for a missing, malformed
    # or non-JSON body, so a precise error can be reported below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "inputs" not in data:
        return jsonify({"error": "Request body must be JSON with an 'inputs' key."}), 400
    image_path = data["inputs"]
    if not isinstance(image_path, str):
        return jsonify({"error": "'inputs' must be a string path to an image file."}), 400
    # SECURITY NOTE(review): this opens an arbitrary client-chosen path on the
    # server's filesystem. Restrict it to an allowed directory, or accept the
    # uploaded image bytes instead, before exposing this endpoint publicly.
    try:
        # The context manager closes the file handle Image.open keeps open,
        # once predict() has finished reading the pixels.
        with Image.open(image_path) as image:
            result = predict(image)
    except Exception as e:
        # Deliberately broad at this HTTP boundary, and kept as 400 for
        # backward compatibility with existing clients.
        return jsonify({"error": str(e)}), 400
    return jsonify({"result": result})
# When executed directly, serve the Flask API on all network interfaces
# (0.0.0.0), port 8000.
# NOTE(review): app.run starts Flask's built-in development server; a
# production WSGI server is usually preferable for real deployments.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)