Spaces:
Running
Running
import os
import shutil
import tempfile

import numpy as np
import requests
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# FastAPI application object for this service.
app = FastAPI()

# Folder that holds the saved per-plant model files.
MODEL_DIRECTORY = "dsanet_models"

# Scratch folder for uploaded files; the TMP_DIR environment variable
# overrides the default location. It is created on import if missing.
TMP_DIR = os.environ.get("TMP_DIR", "/app/temp")
os.makedirs(TMP_DIR, exist_ok=True)
# Class names per plant type. The position of a label inside its list is the
# class index: the prediction code picks `plant_disease_dict[plant][argmax]`.
# A plant whose list has exactly one label is answered without running a model.
# NOTE(review): "okk" under "Blueberry" looks like a placeholder label - confirm.
plant_disease_dict = {
    "Rice": [
        "Blight",
        "Brown_Spots",
    ],
    "Tomato": [
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy",
    ],
    "Strawberry": [
        "Strawberry___Leaf_scorch",
        "Strawberry___healthy",
    ],
    "Potato": [
        "Potato___Early_blight",
        "Potato___Late_blight",
        "Potato___healthy",
    ],
    "Pepperbell": [
        "Pepper,_bell___Bacterial_spot",
        "Pepper,_bell___healthy",
    ],
    "Peach": [
        "Peach___Bacterial_spot",
        "Peach___healthy",
    ],
    "Grape": [
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
    ],
    "Apple": [
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
    ],
    "Cherry": [
        "Cherry___Powdery_mildew",
        "Cherry___healthy",
    ],
    "Corn": [
        "Corn___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn___Common_rust",
        "Corn___Northern_Leaf_Blight",
        "Corn___healthy",
    ],
    "Blueberry": [
        "okk",
    ],
}
# Custom self-attention layer referenced by the saved models.
class SelfAttention(Layer):
    """Keras layer that appends an attention-weighted projection to its input.

    The input goes through three bias-free 1x1 convolutions (query, key and
    value), each producing ``input_channels // reduction_ratio`` channels.
    The query is multiplied with the transposed key, the product is passed
    through a softmax over axis 1, and the result is multiplied with the
    value projection. The layer returns the original input concatenated with
    that attended value along the last axis.

    Args:
        reduction_ratio: Divisor applied to the input channel count to get
            the channel count of the query/key/value projections.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        """Create the query, key and value convolutions for ``input_shape``."""
        projected_channels = input_shape[-1] // self.reduction_ratio
        conv_options = {"kernel_size": 1, "use_bias": False}
        # Creation order (query, key, value) is kept as in the original layer.
        self.query_conv = Conv2D(projected_channels, **conv_options)
        self.key_conv = Conv2D(projected_channels, **conv_options)
        self.value_conv = Conv2D(projected_channels, **conv_options)
        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` concatenated with the attended value projection."""
        q = self.query_conv(inputs)
        k = self.key_conv(inputs)
        v = self.value_conv(inputs)

        # Attention scores: query x key^T, normalised with a softmax on axis 1.
        weights = Softmax(axis=1)(tf.matmul(q, k, transpose_b=True))

        # Weight the value projection, then append it to the untouched input.
        attended = tf.matmul(weights, v)
        return Concatenate(axis=-1)([inputs, attended])

    def get_config(self):
        """Return the base layer config extended with ``reduction_ratio``."""
        base_config = super().get_config()
        return {**base_config, "reduction_ratio": self.reduction_ratio}
# Cache of models loaded at import time, keyed by plant name.
loaded_models = {}


def load_all_models():
    """Populate ``loaded_models`` with one model per plant in ``plant_disease_dict``.

    For each plant name the file ``model_<plant>.keras`` is looked up inside
    ``MODEL_DIRECTORY``. A missing file produces a warning and a failed load
    produces an error message; in both cases the plant is simply left out of
    ``loaded_models`` and loading continues with the next plant.
    """
    for plant_name in plant_disease_dict:
        model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")

        if not os.path.isfile(model_path):
            print(f"⚠ Warning: Model file '{model_path}' not found!")
            continue

        # The "Rice" model is loaded without custom objects; every other
        # model is loaded with the SelfAttention layer registered.
        if plant_name == "Rice":
            load_options = {}
        else:
            load_options = {"custom_objects": {"SelfAttention": SelfAttention}}

        try:
            loaded_models[plant_name] = load_model(model_path, **load_options)
            print(f"✅ Model for {plant_name} loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model '{plant_name}': {e}")


# Load every available model once, when the module is imported.
load_all_models()
# NOTE(review): no route decorator is visible on this handler in this view -
# confirm it is registered on `app`.
async def api_health_check():
    """Return a JSON body reporting that the service is running."""
    status_body = {"status": "Service is running"}
    return JSONResponse(content=status_body)
# Base URL of the external service queried for disease descriptions/remedies.
_DISEASE_INFO_URL = "https://navpan2-sarva-ai-back.hf.space/kotlinback/"
# Upper bound, in seconds, for one external lookup so that a stalled upstream
# service cannot hang the request indefinitely.
_EXTERNAL_API_TIMEOUT = 10


def _fetch_disease_details(disease_label: str) -> dict:
    """Fetch details for ``disease_label`` from the external disease service.

    Returns the decoded JSON object on success. On any failure (network
    error, timeout, non-200 status, undecodable or non-object JSON) a dict
    with a single ``"error"`` key is returned, so callers always get a dict.
    """
    try:
        response = requests.get(
            f"{_DISEASE_INFO_URL}{disease_label}",
            timeout=_EXTERNAL_API_TIMEOUT,
        )
        if response.status_code != 200:
            return {"error": "Failed to fetch external data"}
        details = response.json()
    except Exception as e:
        # The lookup only enriches the answer; a failure here must not turn
        # a successful prediction into a server error.
        return {"error": str(e)}
    if not isinstance(details, dict):
        return {"error": "Failed to fetch external data"}
    return details


def _build_disease_payload(external_data: dict, plant_name: str, disease_label: str) -> dict:
    """Build the response body, filling gaps in ``external_data`` with defaults."""
    # Tolerate null / wrongly typed members instead of failing on `.get`.
    disease_desc = external_data.get("diseaseDesc")
    if not isinstance(disease_desc, dict):
        disease_desc = {}
    remedies = external_data.get("diseaseRemedyList")
    if not isinstance(remedies, list):
        remedies = []

    return {
        "plantName": external_data.get("plantName", plant_name),
        "botanicalName": external_data.get("botanicalName", "Unknown"),
        "diseaseDesc": {
            "diseaseName": disease_desc.get("diseaseName", disease_label),
            "symptoms": disease_desc.get("symptoms", "Not Available"),
            "diseaseCauses": disease_desc.get("diseaseCauses", "Not Available"),
        },
        "diseaseRemedyList": [
            {
                "title": remedy.get("title", "Unknown"),
                "diseaseRemedyShortDesc": remedy.get("diseaseRemedyShortDesc", "Not Available"),
                "diseaseRemedy": remedy.get("diseaseRemedy", "Not Available"),
            }
            for remedy in remedies
            if isinstance(remedy, dict)
        ],
    }


def _safe_upload_suffix(filename) -> str:
    """Return a harmless extension such as ``".jpg"`` taken from ``filename``, or ``""``."""
    extension = os.path.splitext(os.path.basename(filename or ""))[1]
    return extension if extension[1:].isalnum() else ""


# NOTE(review): no route decorator is visible on this handler in this view -
# confirm it is registered on `app`.
# NOTE(review): `requests.get` and `model.predict` are synchronous calls made
# from an `async def` handler - confirm that blocking here is acceptable.
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.

    Plants that have exactly one class in `plant_disease_dict` are answered
    directly with that class, without running a model. For all other plants
    the upload is stored in a uniquely named temporary file inside `TMP_DIR`,
    resized to 224x224, scaled to [0, 1] and passed to the preloaded model;
    the class with the highest score is reported.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.
    Returns:
        JSON response with the predicted class and additional details from an external API.
    Raises:
        HTTPException: 400 when no model is loaded for `plant_name`; 500 when
            the model predicts a class index that has no class name.
    """
    class_names = plant_disease_dict.get(plant_name, [])

    # Single-class plants: there is nothing to classify, report the only class.
    if len(class_names) == 1:
        single_disease = class_names[0]
        external_data = _fetch_disease_details(single_disease)
        return JSONResponse(
            content=_build_disease_payload(external_data, plant_name, single_disease)
        )

    if plant_name not in loaded_models:
        raise HTTPException(status_code=400, detail=f"Invalid plant name or model not loaded: {plant_name}")

    # Retrieve the preloaded model
    model = loaded_models[plant_name]

    temp_path = None
    try:
        # Never use the client-supplied filename as a path: it may contain
        # directory components, and two uploads with the same name would
        # overwrite each other. Only a sanitised extension is kept.
        with tempfile.NamedTemporaryFile(
            dir=TMP_DIR, suffix=_safe_upload_suffix(file.filename), delete=False
        ) as buffer:
            temp_path = buffer.name
            shutil.copyfileobj(file.file, buffer)

        # Load and preprocess the image
        img = image.load_img(temp_path, target_size=(224, 224))
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # Expand dimensions for model input
        img_array = img_array / 255.0  # Normalize

        # Make prediction
        prediction = model.predict(img_array)
        class_index = int(np.argmax(prediction))
    finally:
        # Clean up the temporary file, including when the copy itself failed.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)

    if class_index >= len(class_names):
        raise HTTPException(
            status_code=500,
            detail=f"Model for {plant_name} predicted class index {class_index}, which has no class name",
        )
    class_label = class_names[class_index]

    # Fetch additional data from external API
    external_data = _fetch_disease_details(class_label)
    return JSONResponse(
        content=_build_disease_payload(external_data, plant_name, class_label)
    )
# When executed as a script, serve `app` with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, port=7860, host="0.0.0.0")