navpan2 commited on
Commit
5183e44
·
verified ·
1 Parent(s): 0149f27

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -39
main.py CHANGED
@@ -5,19 +5,18 @@ from fastapi import FastAPI, File, UploadFile
5
  from fastapi.responses import JSONResponse
6
  from io import BytesIO
7
  from PIL import Image
8
- from tensorflow.keras.preprocessing.image import img_to_array
9
- from tensorflow.keras.applications import resnet50
10
- from tensorflow.keras.applications.resnet50 import preprocess_input
11
- import uvicorn
12
  import tensorflow_addons as tfa
13
- # Load the h5 model
14
- custom_objects = {
15
- "Addons>CohenKappa": tfa.metrics.CohenKappa,
16
- }
17
 
18
  # Initialize FastAPI app
19
  app = FastAPI()
20
 
 
 
 
 
 
21
  # Model and class information
22
  model_path = "model.h5"
23
  class_labels = {
@@ -27,12 +26,12 @@ class_labels = {
27
  3: "Apple___healthy",
28
  4: "Background_without_leaves",
29
  5: "Blueberry___healthy",
30
- 6: "Cherry___Powdery_mildew",
31
- 7: "Cherry___healthy",
32
- 8: "Corn___Cercospora_leaf_spot_Gray_leaf_spot",
33
- 9: "Corn___Common_rust",
34
- 10: "Corn___Northern_Leaf_Blight",
35
- 11: "Corn___healthy",
36
  12: "Grape___Black_rot",
37
  13: "Grape___Esca_(Black_Measles)",
38
  14: "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
@@ -55,7 +54,7 @@ class_labels = {
55
  31: "Tomato___Late_blight",
56
  32: "Tomato___Leaf_Mold",
57
  33: "Tomato___Septoria_leaf_spot",
58
- 34: "Tomato___Spider_mites_Two-spotted_spider_mite",
59
  35: "Tomato___Target_Spot",
60
  36: "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
61
  37: "Tomato___Tomato_mosaic_virus",
@@ -64,31 +63,38 @@ class_labels = {
64
 
65
  # Load the model if it exists
66
  if os.path.exists(model_path):
67
- model = tf.keras.models.load_model('model.h5', custom_objects=custom_objects)
68
  print("Model loaded successfully.")
69
  else:
70
  print(f"Model file not found at {model_path}. Please upload the model.")
71
 
72
- # Function to predict crop disease in an image and return the class name
73
- def predict_image(image_data):
74
- try:
75
- # Load the image from binary data
76
- img = Image.open(BytesIO(image_data))
77
- # Resize the image to the target size
78
- img = img.resize((224, 224))
79
- # Convert image to array format for the model
80
- img_array = img_to_array(img)
81
- img_array = np.expand_dims(img_array, axis=0)
82
- img_array = preprocess_input(img_array)
83
 
84
- # Make prediction
85
- prediction = model.predict(img_array)
86
- predicted_class = np.argmax(prediction[0])
87
- class_name = class_labels.get(predicted_class, "Unknown") # Map to class name
88
- return class_name
89
- except Exception as e:
90
- print("Prediction error:", e)
91
- return "Error during prediction"
 
 
 
 
 
 
 
 
 
 
92
 
93
  # Route for health check
94
  @app.get("/health")
@@ -99,12 +105,8 @@ async def api_health_check():
99
  @app.post("/predict")
100
  async def api_predict_image(file: UploadFile = File(...)):
101
  try:
102
- # Read the image file as binary data
103
  image_data = await file.read()
104
-
105
- # Call the prediction function with the image data
106
  prediction = predict_image(image_data)
107
-
108
  return JSONResponse(content={"prediction": prediction})
109
  except Exception as e:
110
  return JSONResponse(content={"error": str(e)})
 
5
  from fastapi.responses import JSONResponse
6
  from io import BytesIO
7
  from PIL import Image
8
+ from tensorflow.keras.preprocessing.image import img_to_array, load_img
 
 
 
9
  import tensorflow_addons as tfa
10
+ import uvicorn
 
 
 
11
 
12
  # Initialize FastAPI app
13
  app = FastAPI()
14
 
15
+ # Register the custom object
16
+ custom_objects = {
17
+ "Addons>CohenKappa": tfa.metrics.CohenKappa,
18
+ }
19
+
20
  # Model and class information
21
  model_path = "model.h5"
22
  class_labels = {
 
26
  3: "Apple___healthy",
27
  4: "Background_without_leaves",
28
  5: "Blueberry___healthy",
29
+ 6: "Cherry_(including_sour)___Powdery_mildew",
30
+ 7: "Cherry_(including_sour)___healthy",
31
+ 8: "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
32
+ 9: "Corn_(maize)___Common_rust_",
33
+ 10: "Corn_(maize)___Northern_Leaf_Blight",
34
+ 11: "Corn_(maize)___healthy",
35
  12: "Grape___Black_rot",
36
  13: "Grape___Esca_(Black_Measles)",
37
  14: "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
 
54
  31: "Tomato___Late_blight",
55
  32: "Tomato___Leaf_Mold",
56
  33: "Tomato___Septoria_leaf_spot",
57
+ 34: "Tomato___Spider_mites Two-spotted_spider_mite",
58
  35: "Tomato___Target_Spot",
59
  36: "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
60
  37: "Tomato___Tomato_mosaic_virus",
 
63
 
64
  # Load the model if it exists
65
  if os.path.exists(model_path):
66
+ model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
67
  print("Model loaded successfully.")
68
  else:
69
  print(f"Model file not found at {model_path}. Please upload the model.")
70
 
71
+ # Function to preprocess input image
72
+ def preprocess_image(image_data, img_size=224):
73
+ img = Image.open(BytesIO(image_data))
74
+ img = img.resize((img_size, img_size))
75
+ img_array = img_to_array(img)
76
+ img_array = img_array / 255.0
77
+ img_array = np.expand_dims(img_array, axis=0)
78
+ return img_array
 
 
 
79
 
80
+ # Predict function
81
+ def predict_image(image_data):
82
+ preprocessed_image = preprocess_image(image_data)
83
+ predictions = model.predict(preprocessed_image)
84
+ class_idx = np.argmax(predictions, axis=1)[0]
85
+ confidence = predictions[0][class_idx]
86
+ class_label = class_labels.get(class_idx, "Unknown")
87
+ if class_label is None:
88
+ return {
89
+ "class_index": class_idx,
90
+ "class_label": "Skipped/Invalid",
91
+ "confidence": confidence,
92
+ }
93
+ return {
94
+ "class_index": class_idx,
95
+ "class_label": class_label,
96
+ "confidence": confidence,
97
+ }
98
 
99
  # Route for health check
100
  @app.get("/health")
 
105
  @app.post("/predict")
106
  async def api_predict_image(file: UploadFile = File(...)):
107
  try:
 
108
  image_data = await file.read()
 
 
109
  prediction = predict_image(image_data)
 
110
  return JSONResponse(content={"prediction": prediction})
111
  except Exception as e:
112
  return JSONResponse(content={"error": str(e)})