import os
import shutil
import tempfile

import numpy as np
import requests
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

app = FastAPI()

# Directory where models are stored (one "model_<Plant>.keras" file per plant).
MODEL_DIRECTORY = "dsanet_models"

# Temporary directory for uploaded files.
TMP_DIR = os.getenv("TMP_DIR", "/app/temp")
os.makedirs(TMP_DIR, exist_ok=True)  # Ensure the temp directory exists

# External service that returns descriptive data (symptoms, remedies, ...) for a class label.
DISEASE_INFO_URL = "https://navpan2-sarva-ai-back.hf.space/kotlinback/"
# Seconds to wait for the external service before falling back to defaults.
EXTERNAL_API_TIMEOUT = 10

# Plant name -> ordered class labels. The order must match the output units of
# the corresponding model, because predictions are mapped back by argmax index.
plant_disease_dict = {
    "Rice": ['Blight', 'Brown_Spots'],
    "Tomato": ['Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
               'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot',
               'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot',
               'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
               'Tomato___healthy'],
    "Strawberry": ['Strawberry___Leaf_scorch', 'Strawberry___healthy'],
    "Potato": ['Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy'],
    "Pepperbell": ['Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy'],
    "Peach": ['Peach___Bacterial_spot', 'Peach___healthy'],
    "Grape": ['Grape___Black_rot', 'Grape___Esca_(Black_Measles)',
              'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy'],
    "Apple": ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust',
              'Apple___healthy'],
    "Cherry": ['Cherry___Powdery_mildew', 'Cherry___healthy'],
    "Corn": ['Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust',
             'Corn___Northern_Leaf_Blight', 'Corn___healthy'],
    # NOTE(review): "okk" looks like a placeholder label; with a single class the
    # /predict endpoint skips the model and queries the external API for it directly.
    "Blueberry": ["okk"],
}


# Custom Self-Attention Layer
@tf.keras.utils.register_keras_serializable()
class SelfAttention(Layer):
    """Self-attention block that appends attended features to its input.

    Uses 1x1 convolutions (so the input is rank-4; channels are read from the
    last axis, i.e. channels_last is assumed) to build query/key/value tensors
    with ``channels // reduction_ratio`` channels, then concatenates the
    attended values onto the input along the channel axis.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        n_channels = input_shape[-1] // self.reduction_ratio
        self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        super().build(input_shape)

    def call(self, inputs):
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)
        # Calculate attention scores (batched matmul over the last two axes).
        attention_scores = tf.matmul(query, key, transpose_b=True)
        # NOTE(review): the softmax runs over axis 1 rather than the last (key)
        # axis. Kept unchanged on purpose: the saved models were trained with it.
        attention_scores = Softmax(axis=1)(attention_scores)
        # Apply attention to values
        attended_value = tf.matmul(attention_scores, value)
        return Concatenate(axis=-1)([inputs, attended_value])

    def get_config(self):
        config = super().get_config()
        config.update({"reduction_ratio": self.reduction_ratio})
        return config


# **Load all models into memory at startup**
loaded_models = {}


def load_all_models():
    """Load every available "model_<Plant>.keras" from MODEL_DIRECTORY into `loaded_models`.

    Missing files and load errors are reported with print() and skipped, so the
    service still starts with whichever models did load.
    """
    for plant_name in plant_disease_dict:
        model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")
        if not os.path.isfile(model_path):
            print(f"⚠ Warning: Model file '{model_path}' not found!")
            continue
        try:
            if plant_name == "Rice":
                # The Rice model does not use the custom layer.
                loaded_models[plant_name] = load_model(model_path)  # Load normally
            else:
                loaded_models[plant_name] = load_model(
                    model_path, custom_objects={"SelfAttention": SelfAttention}
                )
            print(f"✅ Model for {plant_name} loaded successfully!")
        except Exception as e:  # best-effort: one bad model must not stop startup
            print(f"❌ Error loading model '{plant_name}': {e}")


# Load models at startup
load_all_models()


def _fetch_disease_info(disease_name):
    """Fetch descriptive data for `disease_name` from the external API.

    Never raises for network/HTTP/JSON problems: a dict containing an "error"
    key is returned instead, so the caller can build a response from defaults.
    """
    try:
        response = requests.get(f"{DISEASE_INFO_URL}{disease_name}", timeout=EXTERNAL_API_TIMEOUT)
        if response.status_code != 200:
            return {"error": "Failed to fetch external data"}
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        return {"error": str(e)}
    # A non-object JSON body (list, string, null) would break the .get() lookups.
    return data if isinstance(data, dict) else {"error": "Failed to fetch external data"}


def _build_disease_response(plant_name, disease_name, external_data):
    """Shape external data into the API response, using defaults for missing fields."""
    desc = external_data.get("diseaseDesc")
    if not isinstance(desc, dict):  # absent, null or malformed -> use defaults
        desc = {}
    remedies = external_data.get("diseaseRemedyList")
    if not isinstance(remedies, list):
        remedies = []
    return JSONResponse(content={
        "plantName": external_data.get("plantName", plant_name),
        "botanicalName": external_data.get("botanicalName", "Unknown"),
        "diseaseDesc": {
            "diseaseName": desc.get("diseaseName", disease_name),
            "symptoms": desc.get("symptoms", "Not Available"),
            "diseaseCauses": desc.get("diseaseCauses", "Not Available"),
        },
        "diseaseRemedyList": [
            {
                "title": remedy.get("title", "Unknown"),
                "diseaseRemedyShortDesc": remedy.get("diseaseRemedyShortDesc", "Not Available"),
                "diseaseRemedy": remedy.get("diseaseRemedy", "Not Available"),
            }
            for remedy in remedies
            if isinstance(remedy, dict)
        ],
    })


def _save_upload(upload):
    """Copy an uploaded file to a new, uniquely named file in TMP_DIR and return its path."""
    # Only a short alphanumeric extension of the client-supplied name is kept:
    # joining the raw filename onto TMP_DIR would allow path traversal
    # ("../x", "/abs/path") and let concurrent uploads overwrite each other.
    suffix = os.path.splitext(os.path.basename(upload.filename or ""))[1]
    if not (suffix[1:].isalnum() and len(suffix) <= 10):
        suffix = ""
    fd, path = tempfile.mkstemp(dir=TMP_DIR, suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(upload.file, buffer)
    except BaseException:
        os.remove(path)  # don't leave a partial file behind
        raise
    return path


@app.get("/health")
async def api_health_check():
    return JSONResponse(content={"status": "Service is running"})


@app.post("/predict/{plant_name}")
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.

    Returns:
        JSON response with the predicted class and additional details from an external API.

    Raises:
        HTTPException(400): unknown plant / model not loaded, or unreadable image.
    """
    # NOTE(review): requests.get and model.predict are blocking calls inside an
    # async handler and stall the event loop; consider a plain `def` handler or
    # running them in a thread pool.
    classes = plant_disease_dict.get(plant_name, [])

    # Plants with exactly one known class need no model (and the upload is not read).
    if len(classes) == 1:
        single_disease = classes[0]
        return _build_disease_response(plant_name, single_disease, _fetch_disease_info(single_disease))

    if plant_name not in loaded_models:
        raise HTTPException(status_code=400, detail=f"Invalid plant name or model not loaded: {plant_name}")

    temp_path = _save_upload(file)
    try:
        # Retrieve the preloaded model
        model = loaded_models[plant_name]

        # Load and preprocess the image
        try:
            img = image.load_img(temp_path, target_size=(224, 224))
        except (OSError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Uploaded file is not a readable image: {e}") from e
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension for model input
        img_array = img_array / 255.0  # Normalize

        # Make prediction (batch of one -> take the single row of scores)
        prediction = model.predict(img_array, verbose=0)
        class_label = classes[int(np.argmax(prediction[0]))]
    finally:
        # Clean up temporary file before the (potentially slow) external call
        os.remove(temp_path)

    return _build_disease_response(plant_name, class_label, _fetch_disease_info(class_label))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)