navpan2 commited on
Commit
758f3f5
·
verified ·
1 Parent(s): ed070ba

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +126 -0
main.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ import tensorflow as tf
4
+ import numpy as np
5
+ import os
6
+ from tensorflow.keras.models import load_model
7
+ from tensorflow.keras.preprocessing import image
8
+ from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
9
+ import shutil
10
+ import uvicorn
11
+ import requests
12
+
13
+ app = FastAPI()
14
+
15
+ # Directory where models are stored
16
+ MODEL_DIRECTORY = "dsanet_models"
17
+
18
+ # Plant disease class names
19
+ plant_disease_dict = {
20
+ "Rice": ['Blight', 'Brown_Spots'],
21
+ "Tomato": ['Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
22
+ 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot',
23
+ 'Tomato___Spider_mites Two-spotted_spider_mite',
24
+ 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
25
+ 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy'],
26
+ "Strawberry": ['Strawberry___Leaf_scorch', 'Strawberry___healthy'],
27
+ "Potato": ['Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy'],
28
+ "Pepperbell": ['Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy'],
29
+ "Peach": ['Peach___Bacterial_spot', 'Peach___healthy'],
30
+ "Grape": ['Grape___Black_rot', 'Grape___Esca_(Black_Measles)',
31
+ 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy'],
32
+ "Apple": ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy'],
33
+ "Cherry": ['Cherry___Powdery_mildew', 'Cherry___healthy'],
34
+ "Corn": ['Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust',
35
+ 'Corn___Northern_Leaf_Blight', 'Corn___healthy']
36
+ }
37
+
38
+ # Custom Self-Attention Layer
39
+ @tf.keras.utils.register_keras_serializable()
40
+ class SelfAttention(Layer):
41
+ def __init__(self, reduction_ratio=2, **kwargs):
42
+ super(SelfAttention, self).__init__(**kwargs)
43
+ self.reduction_ratio = reduction_ratio
44
+
45
+ def build(self, input_shape):
46
+ n_channels = input_shape[-1] // self.reduction_ratio
47
+ self.query_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
48
+ self.key_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
49
+ self.value_conv = Conv2D(n_channels, kernel_size=1, use_bias=False)
50
+ super(SelfAttention, self).build(input_shape)
51
+
52
+ def call(self, inputs):
53
+ query = self.query_conv(inputs)
54
+ key = self.key_conv(inputs)
55
+ value = self.value_conv(inputs)
56
+
57
+ # Calculate attention scores
58
+ attention_scores = tf.matmul(query, key, transpose_b=True)
59
+ attention_scores = Softmax(axis=1)(attention_scores)
60
+
61
+ # Apply attention to values
62
+ attended_value = tf.matmul(attention_scores, value)
63
+ concatenated_output = Concatenate(axis=-1)([inputs, attended_value])
64
+ return concatenated_output
65
+
66
+ def get_config(self):
67
+ config = super(SelfAttention, self).get_config()
68
+ config.update({"reduction_ratio": self.reduction_ratio})
69
+ return config
70
+
71
+ @app.get("/health")
72
+ async def api_health_check():
73
+ return JSONResponse(content={"status": "Service is running"})
74
+ @app.post("/predict/{plant_name}")
75
+ async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
76
+ """
77
+ API endpoint to predict plant disease from an uploaded image.
78
+
79
+ Args:
80
+ plant_name (str): The plant type (must match a key in `plant_disease_dict`).
81
+ file (UploadFile): The image file uploaded by the user.
82
+
83
+ Returns:
84
+ JSON response with the predicted class.
85
+ """
86
+ # Ensure the plant name is valid
87
+ if plant_name not in plant_disease_dict:
88
+ raise HTTPException(status_code=400, detail="Invalid plant name")
89
+
90
+ # Construct the model path
91
+
92
+ if plant_name == "Rice":
93
+ model = load_model(model_path)
94
+ else:
95
+ model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})
96
+
97
+
98
+ # Check if the model exists
99
+ if not os.path.isfile(model_path):
100
+ raise HTTPException(status_code=404, detail=f"Model file '{plant_name}_model.keras' not found")
101
+
102
+ # Save uploaded file temporarily
103
+ temp_path = f"temp_image_{file.filename}"
104
+ with open(temp_path, "wb") as buffer:
105
+ shutil.copyfileobj(file.file, buffer)
106
+
107
+ try:
108
+ # Load model
109
+ model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})
110
+
111
+ # Load and preprocess the image
112
+ img = image.load_img(temp_path, target_size=(224, 224))
113
+ img_array = image.img_to_array(img)
114
+ img_array = np.expand_dims(img_array, axis=0) # Expand dimensions for model input
115
+ img_array = img_array / 255.0 # Normalize
116
+
117
+ # Make prediction
118
+ prediction = model.predict(img_array)
119
+ predicted_class = plant_disease_dict[plant_name][np.argmax(prediction)]
120
+
121
+ return JSONResponse(content={"plant": plant_name, "predicted_disease": predicted_class})
122
+ finally:
123
+ # Clean up temporary file
124
+ os.remove(temp_path)
125
+ if __name__ == "__main__":
126
+ uvicorn.run(app, host="0.0.0.0", port=7860)