"""FastAPI service that predicts plant diseases with one Keras model per plant.

Each supported plant has its own model file ``model_<Plant>.keras`` inside
``MODEL_DIRECTORY``. The model's arg-max output index is mapped to a label
via ``plant_disease_dict``.
"""

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
import tensorflow as tf
import numpy as np
import os
import shutil
import tempfile
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
import uvicorn
import requests  # NOTE(review): not used in this module; kept in case other code relies on it

app = FastAPI()

# Directory where models are stored
MODEL_DIRECTORY = "dsanet_models"

# Plant disease class names, indexed in the order of each model's output units
# (assumed to match the label order used at training time — confirm).
plant_disease_dict = {
    "Rice": ['Blight', 'Brown_Spots'],
    "Tomato": ['Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy'],
    "Strawberry": ['Strawberry___Leaf_scorch', 'Strawberry___healthy'],
    "Potato": ['Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy'],
    "Pepperbell": ['Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy'],
    "Peach": ['Peach___Bacterial_spot', 'Peach___healthy'],
    "Grape": ['Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy'],
    "Apple": ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy'],
    "Cherry": ['Cherry___Powdery_mildew', 'Cherry___healthy'],
    "Corn": ['Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust', 'Corn___Northern_Leaf_Blight', 'Corn___healthy'],
}


# Custom Self-Attention Layer
@tf.keras.utils.register_keras_serializable()
class SelfAttention(Layer):
    """Self-attention layer that concatenates its input with an attended projection.

    Query, key and value are 1x1 ``Conv2D`` projections to
    ``input_channels // reduction_ratio`` channels. The output is
    ``concat([inputs, attended_value], axis=-1)``. It therefore has
    ``input_channels + input_channels // reduction_ratio`` channels.

    Registered with Keras serialization so saved ``.keras`` models containing
    this layer can be reloaded.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        # Divisor applied to the input channel count for the q/k/v projections.
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        n_channels = input_shape[-1] // self.reduction_ratio
        self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        super(SelfAttention, self).build(input_shape)

    def call(self, inputs):
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)
        # Calculate attention scores.
        # For rank-4 tensors tf.matmul batches over the two leading axes. Assuming
        # channels_last (B, H, W, C') inputs, the scores are (B, H, W, W).
        attention_scores = tf.matmul(query, key, transpose_b=True)
        # NOTE(review): softmax over axis 1 (not the last/key axis typical of
        # attention). Left unchanged on purpose: saved models were trained with it.
        attention_scores = Softmax(axis=1)(attention_scores)
        # Apply attention to values
        attended_value = tf.matmul(attention_scores, value)
        concatenated_output = Concatenate(axis=-1)([inputs, attended_value])
        return concatenated_output

    def get_config(self):
        """Include ``reduction_ratio`` so the layer round-trips through serialization."""
        config = super(SelfAttention, self).get_config()
        config.update({"reduction_ratio": self.reduction_ratio})
        return config


@app.get("/health")
async def api_health_check():
    """Liveness probe: always reports that the service is running."""
    return JSONResponse(content={"status": "Service is running"})


def _model_path_for(plant_name):
    """Return the on-disk path of the Keras model file for ``plant_name``."""
    return os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")


def _preprocess_image(path):
    """Load the image at ``path`` as a batch of one 224x224 image scaled to [0, 1].

    Raises:
        OSError: if the file cannot be decoded as an image (Pillow's
            ``UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    img = image.load_img(path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)  # Add the batch dimension
    return img_array / 255.0  # Normalize pixel values to [0, 1]


@app.post("/predict/{plant_name}")
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.

    Returns:
        JSON response ``{"plant": ..., "predicted_disease": ...}``.

    Raises:
        HTTPException: 400 for an unknown plant or an unreadable image,
            404 when the plant's model file is missing.
    """
    # Ensure the plant name is valid
    if plant_name not in plant_disease_dict:
        raise HTTPException(status_code=400, detail="Invalid plant name")

    model_path = _model_path_for(plant_name)

    # Check existence *before* loading. Otherwise load_model raises on a missing
    # file and the client gets a 500 instead of this 404.
    if not os.path.isfile(model_path):
        raise HTTPException(
            status_code=404,
            detail=f"Model file '{os.path.basename(model_path)}' not found",
        )

    # Store the upload under a unique server-chosen name. The client-supplied
    # filename is never used as a path: it could contain "../" or collide
    # with a concurrent request.
    fd, temp_path = tempfile.mkstemp(prefix="temp_image_")
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Load the model once. Supplying custom_objects for every plant (Rice
        # included) is fine: entries the saved model doesn't reference are ignored.
        # NOTE(review): load_model/predict are blocking CPU work inside an async
        # handler and stall the event loop. Consider caching models or moving this
        # to a threadpool once thread-safety of shared models is confirmed.
        model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})

        try:
            img_array = _preprocess_image(temp_path)
        except OSError as err:
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a readable image"
            ) from err

        # Make prediction
        prediction = model.predict(img_array)
        predicted_class = plant_disease_dict[plant_name][int(np.argmax(prediction))]

        return JSONResponse(content={"plant": plant_name, "predicted_disease": predicted_class})
    finally:
        # Clean up the temporary file, even if the copy, load or predict failed.
        os.remove(temp_path)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)