navpan2 commited on
Commit
98f3b15
·
verified ·
1 Parent(s): c47b658

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +109 -0
main.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from fastapi import FastAPI, File, UploadFile
5
+ from fastapi.responses import JSONResponse
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from tensorflow.keras.preprocessing.image import img_to_array
9
+ from tensorflow.keras.applications import resnet50
10
+ from tensorflow.keras.applications.resnet50 import preprocess_input
11
+ import uvicorn
12
+
13
+ # Initialize FastAPI app
14
+ app = FastAPI()
15
+
16
+ # Model and class information
17
+ model_path = "model.h5"
18
+ class_labels = {
19
+ 0: "Apple___Apple_scab",
20
+ 1: "Apple___Black_rot",
21
+ 2: "Apple___Cedar_apple_rust",
22
+ 3: "Apple___healthy",
23
+ 4: "Background_without_leaves",
24
+ 5: "Blueberry___healthy",
25
+ 6: "Cherry___Powdery_mildew",
26
+ 7: "Cherry___healthy",
27
+ 8: "Corn___Cercospora_leaf_spot_Gray_leaf_spot",
28
+ 9: "Corn___Common_rust",
29
+ 10: "Corn___Northern_Leaf_Blight",
30
+ 11: "Corn___healthy",
31
+ 12: "Grape___Black_rot",
32
+ 13: "Grape___Esca_(Black_Measles)",
33
+ 14: "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
34
+ 15: "Grape___healthy",
35
+ 16: "Orange___Haunglongbing_(Citrus_greening)",
36
+ 17: "Peach___Bacterial_spot",
37
+ 18: "Peach___healthy",
38
+ 19: "Pepper,_bell___Bacterial_spot",
39
+ 20: "Pepper,_bell___healthy",
40
+ 21: "Potato___Early_blight",
41
+ 22: "Potato___Late_blight",
42
+ 23: "Potato___healthy",
43
+ 24: "Raspberry___healthy",
44
+ 25: "Soybean___healthy",
45
+ 26: "Squash___Powdery_mildew",
46
+ 27: "Strawberry___Leaf_scorch",
47
+ 28: "Strawberry___healthy",
48
+ 29: "Tomato___Bacterial_spot",
49
+ 30: "Tomato___Early_blight",
50
+ 31: "Tomato___Late_blight",
51
+ 32: "Tomato___Leaf_Mold",
52
+ 33: "Tomato___Septoria_leaf_spot",
53
+ 34: "Tomato___Spider_mites_Two-spotted_spider_mite",
54
+ 35: "Tomato___Target_Spot",
55
+ 36: "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
56
+ 37: "Tomato___Tomato_mosaic_virus",
57
+ 38: "Tomato___healthy"
58
+ }
59
+
60
+ # Load the model if it exists
61
+ if os.path.exists(model_path):
62
+ model = tf.keras.models.load_model(model_path)
63
+ print("Model loaded successfully.")
64
+ else:
65
+ print(f"Model file not found at {model_path}. Please upload the model.")
66
+
67
+ # Function to predict crop disease in an image and return the class name
68
+ def predict_image(image_data):
69
+ try:
70
+ # Load the image from binary data
71
+ img = Image.open(BytesIO(image_data))
72
+ # Resize the image to the target size
73
+ img = img.resize((224, 224))
74
+ # Convert image to array format for the model
75
+ img_array = img_to_array(img)
76
+ img_array = np.expand_dims(img_array, axis=0)
77
+ img_array = preprocess_input(img_array)
78
+
79
+ # Make prediction
80
+ prediction = model.predict(img_array)
81
+ predicted_class = np.argmax(prediction[0])
82
+ class_name = class_labels.get(predicted_class, "Unknown") # Map to class name
83
+ return class_name
84
+ except Exception as e:
85
+ print("Prediction error:", e)
86
+ return "Error during prediction"
87
+
88
+ # Route for health check
89
+ @app.get("/health")
90
+ async def api_health_check():
91
+ return JSONResponse(content={"status": "Service is running"})
92
+
93
+ # Route for prediction using image via API
94
+ @app.post("/predict")
95
+ async def api_predict_image(file: UploadFile = File(...)):
96
+ try:
97
+ # Read the image file as binary data
98
+ image_data = await file.read()
99
+
100
+ # Call the prediction function with the image data
101
+ prediction = predict_image(image_data)
102
+
103
+ return JSONResponse(content={"prediction": prediction})
104
+ except Exception as e:
105
+ return JSONResponse(content={"error": str(e)})
106
+
107
+ # Run the FastAPI app
108
+ if __name__ == "__main__":
109
+ uvicorn.run(app, host="0.0.0.0", port=7860)