Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -5,19 +5,18 @@ from fastapi import FastAPI, File, UploadFile
|
|
5 |
from fastapi.responses import JSONResponse
|
6 |
from io import BytesIO
|
7 |
from PIL import Image
|
8 |
-
from tensorflow.keras.preprocessing.image import img_to_array
|
9 |
-
from tensorflow.keras.applications import resnet50
|
10 |
-
from tensorflow.keras.applications.resnet50 import preprocess_input
|
11 |
-
import uvicorn
|
12 |
import tensorflow_addons as tfa
|
13 |
-
|
14 |
-
custom_objects = {
|
15 |
-
"Addons>CohenKappa": tfa.metrics.CohenKappa,
|
16 |
-
}
|
17 |
|
18 |
# Initialize FastAPI app
|
19 |
app = FastAPI()
|
20 |
|
|
|
|
|
|
|
|
|
|
|
21 |
# Model and class information
|
22 |
model_path = "model.h5"
|
23 |
class_labels = {
|
@@ -27,12 +26,12 @@ class_labels = {
|
|
27 |
3: "Apple___healthy",
|
28 |
4: "Background_without_leaves",
|
29 |
5: "Blueberry___healthy",
|
30 |
-
6: "
|
31 |
-
7: "
|
32 |
-
8: "
|
33 |
-
9: "
|
34 |
-
10: "
|
35 |
-
11: "
|
36 |
12: "Grape___Black_rot",
|
37 |
13: "Grape___Esca_(Black_Measles)",
|
38 |
14: "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
|
@@ -55,7 +54,7 @@ class_labels = {
|
|
55 |
31: "Tomato___Late_blight",
|
56 |
32: "Tomato___Leaf_Mold",
|
57 |
33: "Tomato___Septoria_leaf_spot",
|
58 |
-
34: "
|
59 |
35: "Tomato___Target_Spot",
|
60 |
36: "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
|
61 |
37: "Tomato___Tomato_mosaic_virus",
|
@@ -64,31 +63,38 @@ class_labels = {
|
|
64 |
|
65 |
# Load the model if it exists
|
66 |
if os.path.exists(model_path):
|
67 |
-
model = tf.keras.models.load_model(
|
68 |
print("Model loaded successfully.")
|
69 |
else:
|
70 |
print(f"Model file not found at {model_path}. Please upload the model.")
|
71 |
|
72 |
-
# Function to
|
73 |
-
def
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
img_array = img_to_array(img)
|
81 |
-
img_array = np.expand_dims(img_array, axis=0)
|
82 |
-
img_array = preprocess_input(img_array)
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
# Route for health check
|
94 |
@app.get("/health")
|
@@ -99,12 +105,8 @@ async def api_health_check():
|
|
99 |
@app.post("/predict")
|
100 |
async def api_predict_image(file: UploadFile = File(...)):
|
101 |
try:
|
102 |
-
# Read the image file as binary data
|
103 |
image_data = await file.read()
|
104 |
-
|
105 |
-
# Call the prediction function with the image data
|
106 |
prediction = predict_image(image_data)
|
107 |
-
|
108 |
return JSONResponse(content={"prediction": prediction})
|
109 |
except Exception as e:
|
110 |
return JSONResponse(content={"error": str(e)})
|
|
|
5 |
from fastapi.responses import JSONResponse
|
6 |
from io import BytesIO
|
7 |
from PIL import Image
|
8 |
+
from tensorflow.keras.preprocessing.image import img_to_array, load_img
|
|
|
|
|
|
|
9 |
import tensorflow_addons as tfa
|
10 |
+
import uvicorn
|
|
|
|
|
|
|
11 |
|
12 |
# Initialize FastAPI app
|
13 |
app = FastAPI()
|
14 |
|
15 |
+
# Register the custom object
|
16 |
+
custom_objects = {
|
17 |
+
"Addons>CohenKappa": tfa.metrics.CohenKappa,
|
18 |
+
}
|
19 |
+
|
20 |
# Model and class information
|
21 |
model_path = "model.h5"
|
22 |
class_labels = {
|
|
|
26 |
3: "Apple___healthy",
|
27 |
4: "Background_without_leaves",
|
28 |
5: "Blueberry___healthy",
|
29 |
+
6: "Cherry_(including_sour)___Powdery_mildew",
|
30 |
+
7: "Cherry_(including_sour)___healthy",
|
31 |
+
8: "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
|
32 |
+
9: "Corn_(maize)___Common_rust_",
|
33 |
+
10: "Corn_(maize)___Northern_Leaf_Blight",
|
34 |
+
11: "Corn_(maize)___healthy",
|
35 |
12: "Grape___Black_rot",
|
36 |
13: "Grape___Esca_(Black_Measles)",
|
37 |
14: "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
|
|
|
54 |
31: "Tomato___Late_blight",
|
55 |
32: "Tomato___Leaf_Mold",
|
56 |
33: "Tomato___Septoria_leaf_spot",
|
57 |
+
34: "Tomato___Spider_mites Two-spotted_spider_mite",
|
58 |
35: "Tomato___Target_Spot",
|
59 |
36: "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
|
60 |
37: "Tomato___Tomato_mosaic_virus",
|
|
|
63 |
|
64 |
# Load the model if it exists
|
65 |
if os.path.exists(model_path):
|
66 |
+
model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
|
67 |
print("Model loaded successfully.")
|
68 |
else:
|
69 |
print(f"Model file not found at {model_path}. Please upload the model.")
|
70 |
|
71 |
+
# Function to preprocess input image
|
72 |
+
def preprocess_image(image_data, img_size=224):
|
73 |
+
img = Image.open(BytesIO(image_data))
|
74 |
+
img = img.resize((img_size, img_size))
|
75 |
+
img_array = img_to_array(img)
|
76 |
+
img_array = img_array / 255.0
|
77 |
+
img_array = np.expand_dims(img_array, axis=0)
|
78 |
+
return img_array
|
|
|
|
|
|
|
79 |
|
80 |
+
# Predict function
|
81 |
+
def predict_image(image_data):
|
82 |
+
preprocessed_image = preprocess_image(image_data)
|
83 |
+
predictions = model.predict(preprocessed_image)
|
84 |
+
class_idx = np.argmax(predictions, axis=1)[0]
|
85 |
+
confidence = predictions[0][class_idx]
|
86 |
+
class_label = class_labels.get(class_idx, "Unknown")
|
87 |
+
if class_label is None:
|
88 |
+
return {
|
89 |
+
"class_index": class_idx,
|
90 |
+
"class_label": "Skipped/Invalid",
|
91 |
+
"confidence": confidence,
|
92 |
+
}
|
93 |
+
return {
|
94 |
+
"class_index": class_idx,
|
95 |
+
"class_label": class_label,
|
96 |
+
"confidence": confidence,
|
97 |
+
}
|
98 |
|
99 |
# Route for health check
|
100 |
@app.get("/health")
|
|
|
105 |
@app.post("/predict")
|
106 |
async def api_predict_image(file: UploadFile = File(...)):
|
107 |
try:
|
|
|
108 |
image_data = await file.read()
|
|
|
|
|
109 |
prediction = predict_image(image_data)
|
|
|
110 |
return JSONResponse(content={"prediction": prediction})
|
111 |
except Exception as e:
|
112 |
return JSONResponse(content={"error": str(e)})
|