# main.py — author: navpan2; commit 9f6582c ("Update main.py", verified); 5.13 kB.
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
import tensorflow as tf
import numpy as np
import os
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
import shutil
import uvicorn
import requests
app = FastAPI()  # ASGI application object; the route handlers in this module register on it.

# Folder holding one saved Keras model per plant, named "model_<Plant>.keras".
MODEL_DIRECTORY = "dsanet_models"
# Class labels per supported plant. The predict endpoint indexes these lists with the
# model's argmax, so each list's order must match its model's output units.
plant_disease_dict = dict(
    Rice=["Blight", "Brown_Spots"],
    Tomato=[
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy",
    ],
    Strawberry=["Strawberry___Leaf_scorch", "Strawberry___healthy"],
    Potato=["Potato___Early_blight", "Potato___Late_blight", "Potato___healthy"],
    Pepperbell=["Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy"],
    Peach=["Peach___Bacterial_spot", "Peach___healthy"],
    Grape=[
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
    ],
    Apple=[
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
    ],
    Cherry=["Cherry___Powdery_mildew", "Cherry___healthy"],
    Corn=[
        "Corn___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn___Common_rust",
        "Corn___Northern_Leaf_Blight",
        "Corn___healthy",
    ],
)
# Custom Self-Attention Layer
@tf.keras.utils.register_keras_serializable()
class SelfAttention(Layer):
    """Self-attention block that appends attended features to its input channels.

    The input is projected by three bias-free 1x1 convolutions (query, key, value),
    each with ``input_channels // reduction_ratio`` filters. The attended values are
    then concatenated onto the original input along the last axis.

    The class is registered via ``register_keras_serializable`` and exposes
    ``reduction_ratio`` through ``get_config`` so that saved models containing it
    can be deserialized.

    NOTE(review): assumes channels-last 4-D inputs ``(batch, H, W, C)``, the Keras
    Conv2D default. TODO confirm for the saved models. Under that assumption
    ``tf.matmul`` batches over ``(batch, H)``, so attention is computed between the
    W positions within each row, not across the full H*W grid.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        # Divisor applied to the input channel count to size the q/k/v projections.
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        # Channel width of the query/key/value projections.
        n_channels = input_shape[-1] // self.reduction_ratio
        # NOTE(review): keep these attribute names and their creation order. Saved
        # weights are presumably matched to these sub-layers when a model is
        # reloaded, so renaming or reordering could break loading existing files.
        self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        super(SelfAttention, self).build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` concatenated with their attended values on the last axis."""
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)
        # Calculate attention scores
        # For rank-4 tensors tf.matmul multiplies the last two axes and batches over
        # the rest: (B, H, W, C') x (B, H, C', W) -> (B, H, W, W).
        attention_scores = tf.matmul(query, key, transpose_b=True)
        # NOTE(review): softmax normalizes over axis 1 (H under the layout assumed
        # above), not over the key axis (-1) as is usual for attention. Trained
        # weights depend on this, so confirm intent before changing it.
        # Note that a new Softmax / Concatenate layer object is created on each call.
        attention_scores = Softmax(axis=1)(attention_scores)
        # Apply attention to values
        # (B, H, W, W) x (B, H, W, C') -> (B, H, W, C').
        attended_value = tf.matmul(attention_scores, value)
        # Output channels: C from the original input plus C' from the attended values.
        concatenated_output = Concatenate(axis=-1)([inputs, attended_value])
        return concatenated_output

    def get_config(self):
        # Add reduction_ratio so the layer can be rebuilt from its config on load.
        config = super(SelfAttention, self).get_config()
        config.update({"reduction_ratio": self.reduction_ratio})
        return config
@app.get("/health")
async def api_health_check():
    """Liveness probe: always answers with a fixed JSON status message."""
    payload = {"status": "Service is running"}
    return JSONResponse(payload)
# Loaded Keras models keyed by plant name, so each model file is read from disk at
# most once per process instead of twice on every request.
_MODEL_CACHE = {}


def _get_model(plant_name):
    """Return the Keras model for ``plant_name``, loading and caching it on first use.

    Raises:
        HTTPException: 404 if ``model_<plant_name>.keras`` is not in MODEL_DIRECTORY.
    """
    model = _MODEL_CACHE.get(plant_name)
    if model is not None:
        return model

    model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")
    # Check existence *before* loading. Otherwise load_model fails first and the
    # client receives a 500 instead of this 404.
    if not os.path.isfile(model_path):
        raise HTTPException(
            status_code=404,
            detail=f"Model file 'model_{plant_name}.keras' not found",
        )
    # Every model, Rice included, is loaded with this mapping. An unused
    # custom_objects entry is harmless for models without the SelfAttention layer.
    model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})
    _MODEL_CACHE[plant_name] = model
    return model


def _load_image_array(path):
    """Load an image file as a (1, 224, 224, 3) float batch scaled to [0, 1].

    Raises:
        HTTPException: 400 if the file cannot be decoded as an image.
    """
    try:
        img = image.load_img(path, target_size=(224, 224))
    except (OSError, ValueError) as err:
        # Pillow raises UnidentifiedImageError (an OSError) for non-image uploads.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from err
    img_array = image.img_to_array(img)
    # Add the batch axis, then scale pixel values from [0, 255] to [0, 1].
    return np.expand_dims(img_array, axis=0) / 255.0


@app.post("/predict/{plant_name}")
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.

    Returns:
        JSON response ``{"plant": ..., "predicted_disease": ...}``.

    Raises:
        HTTPException: 400 for an unknown plant or an unreadable image, 404 when the
            plant's model file is missing, 500 if the model's output does not match
            the configured class labels.

    NOTE(review): load_model and model.predict are synchronous calls made inside
    this async handler, so they block the event loop while they run.
    """
    import tempfile

    # Ensure the plant name is valid
    if plant_name not in plant_disease_dict:
        raise HTTPException(status_code=400, detail="Invalid plant name")

    model = _get_model(plant_name)

    # Save the upload to a uniquely named temp file. The path is never built from
    # the client-supplied filename: it may contain path separators ("../"), may be
    # None, or may collide with another request's file.
    fd, temp_path = tempfile.mkstemp(prefix="temp_image_")
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        img_array = _load_image_array(temp_path)
    finally:
        # Clean up the temporary file even if writing or decoding failed.
        os.remove(temp_path)

    prediction = model.predict(img_array, verbose=0)
    class_names = plant_disease_dict[plant_name]
    class_index = int(np.argmax(prediction))
    if class_index >= len(class_names):
        # The model has more output units than labels configured for this plant.
        raise HTTPException(
            status_code=500,
            detail=f"Model output does not match the class labels for '{plant_name}'",
        )
    return JSONResponse(
        content={"plant": plant_name, "predicted_disease": class_names[class_index]}
    )
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)