Spaces:
Running
Running
File size: 5,129 Bytes
758f3f5 9f6582c 758f3f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
import shutil
import tempfile

import numpy as np
import requests
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# FastAPI application instance; the route decorators in this module register
# the /health and /predict/{plant_name} endpoints on it.
app: FastAPI = FastAPI()
# Folder that holds one saved Keras model file per plant.
MODEL_DIRECTORY = "dsanet_models"

# Ordered class labels for each plant's model. A label's position in its list
# is the output index the prediction endpoint maps back to that label.
plant_disease_dict = dict(
    Rice=["Blight", "Brown_Spots"],
    Tomato=[
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy",
    ],
    Strawberry=["Strawberry___Leaf_scorch", "Strawberry___healthy"],
    Potato=["Potato___Early_blight", "Potato___Late_blight", "Potato___healthy"],
    Pepperbell=["Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy"],
    Peach=["Peach___Bacterial_spot", "Peach___healthy"],
    Grape=[
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
    ],
    Apple=[
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
    ],
    Cherry=["Cherry___Powdery_mildew", "Cherry___healthy"],
    Corn=[
        "Corn___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn___Common_rust",
        "Corn___Northern_Leaf_Blight",
        "Corn___healthy",
    ],
)
# Custom self-attention layer, registered so saved models that contain it can
# be deserialized.
@tf.keras.utils.register_keras_serializable()
class SelfAttention(Layer):
    """Attention block that appends an attended projection to its input.

    Three bias-free 1x1 convolutions project the input into query, key and
    value tensors, each with ``input_channels // reduction_ratio`` channels.
    The attended value is concatenated to the unchanged input along the last
    axis, so the output has ``C + C // reduction_ratio`` channels.

    Args:
        reduction_ratio: Integer divisor applied to the input channel count to
            size the query/key/value projections.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        n_channels = input_shape[-1] // self.reduction_ratio
        self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        # Build the sub-layers here so their weights exist as soon as this
        # layer is built. Keras 3 refuses to restore weights into sub-layers
        # that were created in build() but never built.
        for conv in (self.query_conv, self.key_conv, self.value_conv):
            conv.build(input_shape)
        super().build(input_shape)

    def call(self, inputs):
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)

        # tf.matmul batches over all leading axes and multiplies the last two.
        # Assuming channels-last (B, H, W, C) input, this scores every W
        # position against every other W position within the same row.
        attention_scores = tf.matmul(query, key, transpose_b=True)
        # NOTE(review): normalisation runs over axis 1 (the H axis) rather
        # than the key axis (-1). It is kept as-is because the saved weights
        # were trained with this behaviour; confirm whether it was intended.
        attention_scores = tf.nn.softmax(attention_scores, axis=1)

        # Apply attention to values.
        attended_value = tf.matmul(attention_scores, value)
        # Stateless ops are used directly instead of instantiating fresh
        # Softmax/Concatenate layer objects on every call.
        return tf.concat([inputs, attended_value], axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({"reduction_ratio": self.reduction_ratio})
        return config
@app.get("/health")
async def api_health_check():
    """Report that the service is up; always returns the same fixed JSON body."""
    status_payload = {"status": "Service is running"}
    return JSONResponse(content=status_payload)
@app.post("/predict/{plant_name}")
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.

    Returns:
        JSON response with the plant name and the predicted class label.

    Raises:
        HTTPException: 400 if ``plant_name`` is unknown or the upload cannot be
            decoded as an image; 404 if the plant's model file does not exist.
    """
    # Reject unknown plants before touching the filesystem.
    if plant_name not in plant_disease_dict:
        raise HTTPException(status_code=400, detail="Invalid plant name")

    model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")
    # Check existence *before* loading. load_model() on a missing path raises,
    # which would otherwise surface as a 500 instead of this 404.
    if not os.path.isfile(model_path):
        raise HTTPException(
            status_code=404,
            detail=f"Model file '{os.path.basename(model_path)}' not found",
        )

    # Store the upload under a server-generated temp name. The client-supplied
    # filename is never used in the path: it may contain "../" segments, be
    # None, or collide with a concurrent request that uses the same name.
    fd, temp_path = tempfile.mkstemp(prefix="temp_image_")
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Decode and resize first, so a bad upload is rejected before the
        # comparatively expensive model load.
        try:
            img = image.load_img(temp_path, target_size=(224, 224))
        except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a valid image"
            ) from err
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # add a leading batch axis
        img_array = img_array / 255.0  # scale pixel values into [0, 1]

        # Load the model exactly once, always registering the custom layer.
        # Keras only consults custom_objects for classes the saved model
        # references.
        # NOTE(review): the model is reloaded on every request, and
        # load_model()/predict() block the event loop inside this async
        # handler. Consider caching models and offloading to a thread pool.
        model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})

        prediction = model.predict(img_array)
        # The highest-scoring index of the single sample selects the label
        # from this plant's ordered class list.
        predicted_class = plant_disease_dict[plant_name][int(np.argmax(prediction[0]))]
        return JSONResponse(content={"plant": plant_name, "predicted_disease": predicted_class})
    finally:
        # Always remove the temp file, even if writing, decoding or inference failed.
        os.remove(temp_path)
if __name__ == "__main__":
    # Run the app with uvicorn, listening on all network interfaces on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)