navpan2 commited on
Commit
ffe6f85
·
verified ·
1 Parent(s): 984e213

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +42 -23
main.py CHANGED
@@ -8,13 +8,16 @@ from tensorflow.keras.preprocessing import image
8
  from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
9
  import shutil
10
  import uvicorn
11
- import requests
12
 
13
  app = FastAPI()
14
 
15
  # Directory where models are stored
16
  MODEL_DIRECTORY = "dsanet_models"
17
 
 
 
 
 
18
  # Plant disease class names
19
  plant_disease_dict = {
20
  "Rice": ['Blight', 'Brown_Spots'],
@@ -34,8 +37,7 @@ plant_disease_dict = {
34
  "Corn": ['Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust',
35
  'Corn___Northern_Leaf_Blight', 'Corn___healthy']
36
  }
37
- TMP_DIR = os.getenv("TMP_DIR", "/app/temp")
38
- os.makedirs(TMP_DIR, exist_ok=True)
39
  # Custom Self-Attention Layer
40
  @tf.keras.utils.register_keras_serializable()
41
  class SelfAttention(Layer):
@@ -69,9 +71,39 @@ class SelfAttention(Layer):
69
  config.update({"reduction_ratio": self.reduction_ratio})
70
  return config
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  @app.get("/health")
73
  async def api_health_check():
74
  return JSONResponse(content={"status": "Service is running"})
 
 
75
  @app.post("/predict/{plant_name}")
76
  async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
77
  """
@@ -85,32 +117,17 @@ async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
85
  JSON response with the predicted class.
86
  """
87
  # Ensure the plant name is valid
88
- if plant_name not in plant_disease_dict:
89
- raise HTTPException(status_code=400, detail="Invalid plant name")
90
-
91
- # Construct the model path
92
- model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")
93
- if plant_name == "Rice":
94
- model = load_model(model_path)
95
- else:
96
- model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})
97
-
98
-
99
- # Check if the model exists
100
- if not os.path.isfile(model_path):
101
- raise HTTPException(status_code=404, detail=f"Model file '{plant_name}_model.keras' not found")
102
 
103
  # Save uploaded file temporarily
104
-
105
-
106
- # Define the temp file path
107
  temp_path = os.path.join(TMP_DIR, file.filename)
108
  with open(temp_path, "wb") as buffer:
109
  shutil.copyfileobj(file.file, buffer)
110
 
111
  try:
112
- # Load model
113
- model = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})
114
 
115
  # Load and preprocess the image
116
  img = image.load_img(temp_path, target_size=(224, 224))
@@ -126,5 +143,7 @@ async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
126
  finally:
127
  # Clean up temporary file
128
  os.remove(temp_path)
 
 
129
  if __name__ == "__main__":
130
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
8
  from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate
9
  import shutil
10
  import uvicorn
 
11
 
12
  app = FastAPI()
13
 
14
  # Directory where models are stored
15
  MODEL_DIRECTORY = "dsanet_models"
16
 
17
+ # Temporary directory for uploaded files
18
+ TMP_DIR = os.getenv("TMP_DIR", "/app/temp")
19
+ os.makedirs(TMP_DIR, exist_ok=True) # Ensure the temp directory exists
20
+
21
  # Plant disease class names
22
  plant_disease_dict = {
23
  "Rice": ['Blight', 'Brown_Spots'],
 
37
  "Corn": ['Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust',
38
  'Corn___Northern_Leaf_Blight', 'Corn___healthy']
39
  }
40
+
 
41
  # Custom Self-Attention Layer
42
  @tf.keras.utils.register_keras_serializable()
43
  class SelfAttention(Layer):
 
71
  config.update({"reduction_ratio": self.reduction_ratio})
72
  return config
73
 
74
+
75
+ # **Load all models into memory at startup**
76
+ loaded_models = {}
77
+
78
+ def load_all_models():
79
+ """
80
+ Load all models from the `dsanet_models` directory at startup.
81
+ """
82
+ global loaded_models
83
+ for plant_name in plant_disease_dict.keys():
84
+ model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")
85
+
86
+ if os.path.isfile(model_path):
87
+ try:
88
+ if plant_name == "Rice":
89
+ loaded_models[plant_name] = load_model(model_path) # Load normally
90
+ else:
91
+ loaded_models[plant_name] = load_model(model_path, custom_objects={"SelfAttention": SelfAttention})
92
+ print(f"✅ Model for {plant_name} loaded successfully!")
93
+ except Exception as e:
94
+ print(f"❌ Error loading model '{plant_name}': {e}")
95
+ else:
96
+ print(f"⚠ Warning: Model file '{model_path}' not found!")
97
+
98
+ # Load models at startup
99
+ load_all_models()
100
+
101
+
102
  @app.get("/health")
103
  async def api_health_check():
104
  return JSONResponse(content={"status": "Service is running"})
105
+
106
+
107
  @app.post("/predict/{plant_name}")
108
  async def predict_plant_disease(plant_name: str, file: UploadFile = File(...)):
109
  """
 
117
  JSON response with the predicted class.
118
  """
119
  # Ensure the plant name is valid
120
+ if plant_name not in loaded_models:
121
+ raise HTTPException(status_code=400, detail=f"Invalid plant name or model not loaded: {plant_name}")
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  # Save uploaded file temporarily
 
 
 
124
  temp_path = os.path.join(TMP_DIR, file.filename)
125
  with open(temp_path, "wb") as buffer:
126
  shutil.copyfileobj(file.file, buffer)
127
 
128
  try:
129
+ # Retrieve the preloaded model
130
+ model = loaded_models[plant_name]
131
 
132
  # Load and preprocess the image
133
  img = image.load_img(temp_path, target_size=(224, 224))
 
143
  finally:
144
  # Clean up temporary file
145
  os.remove(temp_path)
146
+
147
+
148
  if __name__ == "__main__":
149
+ uvicorn.run(app, host="0.0.0.0", port=7860)