Spaces:
Running
Running
Create main.py
Browse files
main.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException
|
2 |
+
from fastapi.responses import JSONResponse
|
3 |
+
import tensorflow as tf
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
from tensorflow.keras.models import load_model
|
7 |
+
from tensorflow.keras.preprocessing import image
|
8 |
+
from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
|
9 |
+
import shutil
|
10 |
+
import uvicorn
|
11 |
+
import requests
|
12 |
+
|
13 |
+
app = FastAPI()
|
14 |
+
|
15 |
+
# Directory where models are stored
|
16 |
+
MODEL_DIRECTORY = "dsanet_models"
|
17 |
+
|
18 |
+
# Plant disease class names
|
19 |
+
plant_disease_dict = {
|
20 |
+
"Rice": ['Blight', 'Brown_Spots'],
|
21 |
+
"Tomato": ['Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
|
22 |
+
'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot',
|
23 |
+
'Tomato___Spider_mites Two-spotted_spider_mite',
|
24 |
+
'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
|
25 |
+
'Tomato___Tomato_mosaic_virus', 'Tomato___healthy'],
|
26 |
+
"Strawberry": ['Strawberry___Leaf_scorch', 'Strawberry___healthy'],
|
27 |
+
"Potato": ['Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy'],
|
28 |
+
"Pepperbell": ['Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy'],
|
29 |
+
"Peach": ['Peach___Bacterial_spot', 'Peach___healthy'],
|
30 |
+
"Grape": ['Grape___Black_rot', 'Grape___Esca_(Black_Measles)',
|
31 |
+
'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy'],
|
32 |
+
"Apple": ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy'],
|
33 |
+
"Cherry": ['Cherry___Powdery_mildew', 'Cherry___healthy'],
|
34 |
+
"Corn": ['Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust',
|
35 |
+
'Corn___Northern_Leaf_Blight', 'Corn___healthy']
|
36 |
+
}
|
37 |
+
|
38 |
+
# Custom Self-Attention Layer
|
39 |
+
@tf.keras.utils.register_keras_serializable()
|
40 |
+
class SelfAttention(Layer):
|
41 |
+
def __init__(self, reduction_ratio=2, **kwargs):
|
42 |
+
super(SelfAttention, self).__init__(**kwargs)
|
43 |
+
self.reduction_ratio = reduction_ratio
|
44 |
+
|
45 |
+
def build(self, input_shape):
|
46 |
+
n_channels = input_shape[-1] // self.reduction_ratio
|
47 |
+
self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
|
48 |
+
self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
|
49 |
+
self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
|
50 |
+
super(SelfAttention, self).build(input_shape)
|
51 |
+
|
52 |
+
def call(self, inputs):
|
53 |
+
query = self.query_conv(inputs)
|
54 |
+
key = self.key_conv(inputs)
|
55 |
+
value = self.value_conv(inputs)
|
56 |
+
|
57 |
+
# Calculate attention scores
|
58 |
+
attention_scores = tf.matmul(query, key, transpose_b=True)
|
59 |
+
attention_scores = Softmax(axis=1)(attention_scores)
|
60 |
+
|
61 |
+
# Apply attention to values
|
62 |
+
attended_value = tf.matmul(attention_scores, value)
|
63 |
+
concatenated_output = Concatenate(axis=-1)([inputs, attended_value])
|
64 |
+
return concatenated_output
|
65 |
+
|
66 |
+
def get_config(self):
|
67 |
+
config = super(SelfAttention, self).get_config()
|
68 |
+
config.update({"reduction_ratio": self.reduction_ratio})
|
69 |
+
return config
|
70 |
+
|
71 |
+
@app.get("/health")
|
72 |
+
async def api_health_check():
|
73 |
+
return JSONResponse(content={"status": "Service is running"})
|
74 |
+
@app.post("/predict/{plant_name}")
|
75 |
+
async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
|
76 |
+
"""
|
77 |
+
API endpoint to predict plant disease from an uploaded image.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
plant_name (str): The plant type (must match a key in `plant_disease_dict`).
|
81 |
+
file (UploadFile): The image file uploaded by the user.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
JSON response with the predicted class.
|
85 |
+
"""
|
86 |
+
# Ensure the plant name is valid
|
87 |
+
if plant_name not in plant_disease_dict:
|
88 |
+
raise HTTPException(status_code=400, detail="Invalid plant name")
|
89 |
+
|
90 |
+
# Construct the model path
|
91 |
+
|
92 |
+
if plant_name == "Rice":
|
93 |
+
model = load_model(model_path)
|
94 |
+
else:
|
95 |
+
model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})
|
96 |
+
|
97 |
+
|
98 |
+
# Check if the model exists
|
99 |
+
if not os.path.isfile(model_path):
|
100 |
+
raise HTTPException(status_code=404, detail=f"Model file '{plant_name}_model.keras' not found")
|
101 |
+
|
102 |
+
# Save uploaded file temporarily
|
103 |
+
temp_path = f"temp_image_{file.filename}"
|
104 |
+
with open(temp_path, "wb") as buffer:
|
105 |
+
shutil.copyfileobj(file.file, buffer)
|
106 |
+
|
107 |
+
try:
|
108 |
+
# Load model
|
109 |
+
model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})
|
110 |
+
|
111 |
+
# Load and preprocess the image
|
112 |
+
img = image.load_img(temp_path, target_size=(224, 224))
|
113 |
+
img_array = image.img_to_array(img)
|
114 |
+
img_array = np.expand_dims(img_array, axis=0) # Expand dimensions for model input
|
115 |
+
img_array = img_array / 255.0 # Normalize
|
116 |
+
|
117 |
+
# Make prediction
|
118 |
+
prediction = model.predict(img_array)
|
119 |
+
predicted_class = plant_disease_dict[plant_name][np.argmax(prediction)]
|
120 |
+
|
121 |
+
return JSONResponse(content={"plant": plant_name, "predicted_disease": predicted_class})
|
122 |
+
finally:
|
123 |
+
# Clean up temporary file
|
124 |
+
os.remove(temp_path)
|
125 |
+
if __name__ == "__main__":
|
126 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|