# Author: navpan2
# Commit message: "Update main.py"
# Commit: 601b2b5 (verified)
import os
import shutil
import tempfile

import numpy as np
import requests
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
app = FastAPI()

# Folder holding one `model_<PlantName>.keras` file per supported plant.
MODEL_DIRECTORY = "dsanet_models"

# Scratch space for uploaded images; the TMP_DIR environment variable overrides the default.
TMP_DIR = os.environ.get("TMP_DIR", "/app/temp")
# Created eagerly at import time so request handlers can write into it immediately.
os.makedirs(TMP_DIR, exist_ok=True)
# Class names per plant, keyed by the `plant_name` path parameter of the predict endpoint.
# Each list is indexed by the model's output position (the prediction's argmax selects
# an entry), so the order inside each list must not change.
plant_disease_dict = {
    "Rice": ["Blight", "Brown_Spots"],
    "Tomato": [
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy",
    ],
    "Strawberry": ["Strawberry___Leaf_scorch", "Strawberry___healthy"],
    "Potato": ["Potato___Early_blight", "Potato___Late_blight", "Potato___healthy"],
    "Pepperbell": ["Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy"],
    "Peach": ["Peach___Bacterial_spot", "Peach___healthy"],
    "Grape": [
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
    ],
    "Apple": [
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
    ],
    "Cherry": ["Cherry___Powdery_mildew", "Cherry___healthy"],
    "Corn": [
        "Corn___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn___Common_rust",
        "Corn___Northern_Leaf_Blight",
        "Corn___healthy",
    ],
    # NOTE(review): "okk" looks like a placeholder. Plants with exactly one entry skip
    # inference in predict_plant_disease and send that label straight to the external
    # service — confirm this label is intended.
    "Blueberry": ["okk"],
}
# Custom self-attention layer; passed as a custom object when loading every non-Rice model.
@tf.keras.utils.register_keras_serializable()
class SelfAttention(Layer):
    """Append an attention-weighted projection of the input to the input itself.

    Query, key and value are bias-free 1x1 ``Conv2D`` projections whose channel count is
    ``input_shape[-1] // reduction_ratio``. The output concatenates ``inputs`` and the
    attended values along the last axis.

    NOTE(review): the saved ``.keras`` models presumably depend on the sublayer attribute
    names (``query_conv``/``key_conv``/``value_conv``) and the ``reduction_ratio`` config
    key, and on the exact math below — keep them unchanged.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        super().__init__(**kwargs)
        # Divisor applied to the input channel count for the q/k/v projections.
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        reduced_channels = input_shape[-1] // self.reduction_ratio

        def pointwise_projection():
            # Bias-free 1x1 convolution down to the reduced channel count.
            return Conv2D(reduced_channels, kernel_size=1, use_bias=False)

        # Created in query, key, value order, exactly as before.
        self.query_conv = pointwise_projection()
        self.key_conv = pointwise_projection()
        self.value_conv = pointwise_projection()
        super().build(input_shape)

    def call(self, inputs):
        q = self.query_conv(inputs)
        k = self.key_conv(inputs)
        v = self.value_conv(inputs)
        # tf.matmul batches over all leading axes. Assuming a channels-last 4-D input
        # (TODO confirm), the scores are (batch, H, W, W): attention within each row.
        raw_scores = tf.matmul(q, k, transpose_b=True)
        # NOTE(review): normalises over axis 1, not the key axis (-1). Left as-is because
        # trained weights were fitted with this behaviour.
        weights = Softmax(axis=1)(raw_scores)
        attended = tf.matmul(weights, v)
        return Concatenate(axis=-1)([inputs, attended])

    def get_config(self):
        # Base-layer config plus the one constructor argument this layer adds.
        return {**super().get_config(), "reduction_ratio": self.reduction_ratio}
# Loaded Keras models keyed by plant name; filled once at import time below.
loaded_models = {}


def load_all_models():
    """Load each plant's model from MODEL_DIRECTORY into `loaded_models`.

    For every key of `plant_disease_dict` this looks for ``model_<plant>.keras``.
    A missing file or a load failure is reported with print() and skipped, so startup
    continues with whatever models did load.
    """
    for plant_name in plant_disease_dict:
        model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")
        if not os.path.isfile(model_path):
            print(f"⚠ Warning: Model file '{model_path}' not found!")
            continue
        # The Rice model is loaded plainly; every other model gets SelfAttention
        # supplied explicitly as a custom object.
        load_kwargs = {} if plant_name == "Rice" else {"custom_objects": {"SelfAttention": SelfAttention}}
        try:
            loaded_models[plant_name] = load_model(model_path, **load_kwargs)
            print(f"✅ Model for {plant_name} loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model '{plant_name}': {e}")


# Populate the model cache as soon as the module is imported.
load_all_models()
@app.get("/health")
async def api_health_check():
    """Health-check endpoint: respond with a fixed JSON status message."""
    payload = {"status": "Service is running"}
    return JSONResponse(payload)
# External service returning descriptive details (symptoms, causes, remedies) for a label.
_DISEASE_INFO_URL = "https://navpan2-sarva-ai-back.hf.space/kotlinback/"
# requests has no default timeout; without one a stalled service hangs the request forever.
_DISEASE_INFO_TIMEOUT_S = 10
# Image size (height, width) fed to every model.
_MODEL_INPUT_SIZE = (224, 224)


def _fetch_disease_details(label):
    """Return the external service's JSON object for `label`.

    Never raises for service problems: a network error, non-200 status, or invalid or
    non-object JSON yields ``{"error": <message>}`` so callers fall back to defaults.
    """
    try:
        response = requests.get(f"{_DISEASE_INFO_URL}{label}", timeout=_DISEASE_INFO_TIMEOUT_S)
        if response.status_code != 200:
            return {"error": "Failed to fetch external data"}
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # Best-effort enrichment: a failing info service must not fail the prediction.
        return {"error": str(e)}
    if not isinstance(data, dict):
        return {"error": "Unexpected external data format"}
    return data


def _build_disease_response(plant_name, label, external_data):
    """Shape `external_data` into the endpoint's JSON response.

    Missing fields fall back to defaults: `plant_name` for the plant, `label` for the
    disease name, and "Unknown"/"Not Available" for the rest.
    """
    # Guard against null or wrongly-typed nested fields, which previously raised a 500.
    desc = external_data.get("diseaseDesc")
    if not isinstance(desc, dict):
        desc = {}
    remedies = external_data.get("diseaseRemedyList")
    if not isinstance(remedies, list):
        remedies = []
    return JSONResponse(content={
        "plantName": external_data.get("plantName", plant_name),
        "botanicalName": external_data.get("botanicalName", "Unknown"),
        "diseaseDesc": {
            "diseaseName": desc.get("diseaseName", label),
            "symptoms": desc.get("symptoms", "Not Available"),
            "diseaseCauses": desc.get("diseaseCauses", "Not Available"),
        },
        "diseaseRemedyList": [
            {
                "title": remedy.get("title", "Unknown"),
                "diseaseRemedyShortDesc": remedy.get("diseaseRemedyShortDesc", "Not Available"),
                "diseaseRemedy": remedy.get("diseaseRemedy", "Not Available"),
            }
            for remedy in remedies
            if isinstance(remedy, dict)
        ],
    })


def _image_to_batch(path):
    """Load the image at `path` as a batch of one, resized and divided by 255.0.

    Raises:
        HTTPException: 400 if the file cannot be read as an image.
    """
    try:
        img = image.load_img(path, target_size=_MODEL_INPUT_SIZE)
    except (OSError, ValueError) as err:
        # e.g. PIL cannot identify the upload as an image: a client error, not a 500.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from err
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    return batch / 255.0


@app.post("/predict/{plant_name}")
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.
    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.
    Returns:
        JSONResponse with plantName, botanicalName, diseaseDesc and diseaseRemedyList for
        the predicted class, enriched from the external disease-info service.
    Raises:
        HTTPException: 400 if no model is loaded for `plant_name`, or if the upload is
            not a readable image.
    """
    classes = plant_disease_dict.get(plant_name, [])
    # A plant with a single known class needs no inference; the upload is not read.
    if len(classes) == 1:
        label = classes[0]
        return _build_disease_response(plant_name, label, _fetch_disease_details(label))

    if plant_name not in loaded_models:
        raise HTTPException(status_code=400, detail=f"Invalid plant name or model not loaded: {plant_name}")
    model = loaded_models[plant_name]

    # The client-supplied filename is untrusted: it may contain "../" or be absolute. Store
    # the upload under a unique server-generated name inside TMP_DIR instead. The unique
    # name also stops concurrent uploads that share a filename from deleting each other's file.
    fd, temp_path = tempfile.mkstemp(prefix="upload_", dir=TMP_DIR)
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        # NOTE(review): model.predict and the requests call are blocking inside an async
        # endpoint and stall the event loop; consider a threadpool or a sync `def`.
        prediction = model.predict(_image_to_batch(temp_path))
    finally:
        # Runs even if the copy, decode or prediction fails, so no temp file is left behind.
        os.remove(temp_path)

    # Model output index i corresponds to classes[i].
    class_label = classes[np.argmax(prediction)]
    return _build_disease_response(plant_name, class_label, _fetch_disease_details(class_label))
if __name__ == "__main__":
    # Run the API directly with uvicorn, listening on all interfaces.
    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)