# Spaces:
# Running
# Running
import functools
import os
import shutil
import tempfile

import numpy as np
import requests
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# ASGI application object; the `uvicorn.run(app, ...)` call at the bottom of
# the file serves it when the module is executed as a script.
# NOTE(review): no @app.get / @app.post route decorators are visible in this
# file, so the async handlers defined below may not actually be registered as
# endpoints — confirm against the deployed source.
app: FastAPI = FastAPI()
# Folder holding one Keras model per plant, named "model_<plant>.keras".
MODEL_DIRECTORY = "dsanet_models"

# Plant name -> ordered class labels for that plant's model.
# The handler picks the label at the arg-max index of the model output, so the
# order of each list must match the model's output units; never reorder.
plant_disease_dict = {
    "Rice": [
        "Blight",
        "Brown_Spots",
    ],
    "Tomato": [
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy",
    ],
    "Strawberry": [
        "Strawberry___Leaf_scorch",
        "Strawberry___healthy",
    ],
    "Potato": [
        "Potato___Early_blight",
        "Potato___Late_blight",
        "Potato___healthy",
    ],
    "Pepperbell": [
        "Pepper,_bell___Bacterial_spot",
        "Pepper,_bell___healthy",
    ],
    "Peach": [
        "Peach___Bacterial_spot",
        "Peach___healthy",
    ],
    "Grape": [
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
    ],
    "Apple": [
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
    ],
    "Cherry": [
        "Cherry___Powdery_mildew",
        "Cherry___healthy",
    ],
    "Corn": [
        "Corn___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn___Common_rust",
        "Corn___Northern_Leaf_Blight",
        "Corn___healthy",
    ],
}
# Custom Self-Attention Layer
class SelfAttention(Layer):
    """Self-attention block that appends attended features to its input.

    Three bias-free 1x1 convolutions project the input into query, key and
    value tensors, each with ``input_channels // reduction_ratio`` channels.
    The attention-weighted values are concatenated onto the unchanged input
    along the last axis, so the output has
    ``input_channels + input_channels // reduction_ratio`` channels.

    Args:
        reduction_ratio: Divisor applied to the input channel count to obtain
            the query/key/value width.
        **kwargs: Forwarded to the Keras ``Layer`` base class.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        # Sub-layers with weights are created here (not in `call`) so Keras
        # tracks them. Keep the attribute names stable: presumably the saved
        # .keras weights are matched to these attributes on load — confirm
        # before renaming.
        n_channels = input_shape[-1] // self.reduction_ratio
        self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
        super().build(input_shape)

    def call(self, inputs):
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)

        # transpose_b swaps the last two axes of `key`. Assuming channels-last
        # (batch, H, W, C) feature maps (TODO confirm), the scores have shape
        # (batch, H, W, W).
        attention_scores = tf.matmul(query, key, transpose_b=True)
        # NOTE(review): the softmax runs over axis 1, not the last axis as in
        # typical attention. It is kept unchanged because the saved models were
        # presumably trained with this normalisation; changing it would alter
        # predictions. Stateless op used instead of building a new Softmax
        # layer object on every call.
        attention_scores = tf.nn.softmax(attention_scores, axis=1)

        # Apply attention to values, then append them to the original features.
        attended_value = tf.matmul(attention_scores, value)
        return tf.concat([inputs, attended_value], axis=-1)

    def get_config(self):
        # Serialize the constructor argument so load_model can rebuild the layer.
        config = super().get_config()
        config.update({"reduction_ratio": self.reduction_ratio})
        return config
async def api_health_check():
    """Report that the service process is up.

    Always returns the same constant payload; it does not check whether any
    model can be loaded.

    Returns:
        JSONResponse: body ``{"status": "Service is running"}``.

    NOTE(review): no route decorator is visible on this handler; it is only
    reachable over HTTP if it is registered with ``app`` — confirm.
    """
    status_body = {"status": "Service is running"}
    return JSONResponse(content=status_body)
@functools.lru_cache(maxsize=None)
def _load_plant_model(model_path):
    """Load the Keras model stored at *model_path*, memoized per process.

    Reading a .keras file is expensive, so each model is loaded once and
    reused across requests. The handler only passes paths built from plant
    names already validated against ``plant_disease_dict``, so the cache holds
    at most one model per supported plant. ``SelfAttention`` is always
    registered as a custom object; a model that does not use the layer simply
    ignores the entry.
    """
    return load_model(model_path, custom_objects={"SelfAttention": SelfAttention})


async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
    """
    API endpoint to predict plant disease from an uploaded image.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.

    Returns:
        JSONResponse: ``{"plant": plant_name, "predicted_disease": label}``,
        where ``label`` is the entry of ``plant_disease_dict[plant_name]`` at
        the arg-max index of the model output.

    Raises:
        HTTPException: 400 if ``plant_name`` is not a known plant; 404 if the
            model file for that plant does not exist.
    """
    # Ensure the plant name is valid
    if plant_name not in plant_disease_dict:
        raise HTTPException(status_code=400, detail="Invalid plant name")

    # Check the model exists *before* loading it; otherwise load_model raises
    # first and the client gets a 500 instead of this 404.
    model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")
    if not os.path.isfile(model_path):
        raise HTTPException(
            status_code=404,
            detail=f"Model file '{os.path.basename(model_path)}' not found",
        )

    model = _load_plant_model(model_path)

    # Save the upload to a uniquely named temp file. The client-supplied
    # filename is never used as a path: it may contain "../" (path traversal),
    # it may be missing, and concurrent uploads with the same name would
    # overwrite each other. Only its extension is kept, as a hint.
    suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
    fd, temp_path = tempfile.mkstemp(prefix="temp_image_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Load and preprocess the image: resize to 224x224, add a batch axis,
        # and divide by 255 (presumably 8-bit pixels -> [0, 1]).
        img = image.load_img(temp_path, target_size=(224, 224))
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = img_array / 255.0

        # Make prediction; the arg-max index selects the label from the list.
        prediction = model.predict(img_array)
        predicted_index = int(np.argmax(prediction))
        predicted_class = plant_disease_dict[plant_name][predicted_index]
        return JSONResponse(content={"plant": plant_name, "predicted_disease": predicted_class})
    finally:
        # Clean up temporary file (also runs if copying the upload failed).
        os.remove(temp_path)
if __name__ == "__main__":
    # When executed directly, serve the app with uvicorn on all network
    # interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )