File size: 1,860 Bytes
3d95070
 
f66c549
3d95070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from keras.models import load_model
import numpy as np
from PIL import Image
import io

# The ASGI application that the routes in this module are registered on.
app = FastAPI()

# Trained Keras model file. A relative path is resolved against the process's
# current working directory; change this if your .h5 file lives elsewhere.
MODEL_PATH = 'keras_model.h5'

# Loaded once, when this module is imported, and reused by every request.
model = load_model(MODEL_PATH)

def preprocess_image(img, size=(224, 224)):
    """Turn a PIL image into a one-image float32 batch for the model.

    Args:
        img: A PIL ``Image`` in any mode (RGB, RGBA, L, P, ...).
        size: Target ``(width, height)`` passed to ``Image.resize``. The
            default of 224x224 is the input size this service assumes the
            model was trained on.

    Returns:
        A float32 NumPy array of shape ``(1, height, width, 3)`` with pixel
        values scaled from 0-255 into the range [0, 1].
    """
    # Uploads are often RGBA (PNG with alpha), grayscale ("L") or palette
    # ("P") images. Without this conversion the array would not be H x W x 3
    # and model.predict would reject it.
    # NOTE(review): assumes the model was trained on 3-channel RGB input — confirm.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img = img.resize(size)
    img_array = np.array(img).astype("float32") / 255  # scale to [0, 1]
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)

def predict_class(img):
    """Run the module-level ``model`` on a single PIL image.

    The image is first passed through ``preprocess_image``; the return value
    is whatever ``model.predict`` produces for that batch of one.
    """
    return model.predict(preprocess_image(img))

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: The uploaded image, sent as a multipart form field.

    Returns:
        A dict with ``"prediction"`` (the name of the highest-scoring class)
        and ``"probabilities"`` (the first row of the model output, converted
        to a plain list).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # Local import: only this endpoint needs it.
    from fastapi.concurrency import run_in_threadpool

    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents))
        # Image.open only parses the header; load() forces a full decode so
        # empty, truncated or non-image uploads are rejected here with a 400
        # instead of crashing later as a 500.
        img.load()
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    # model.predict is blocking CPU work. Running it in the threadpool keeps
    # the event loop free to serve other requests in the meantime.
    prediction = await run_in_threadpool(predict_class, img)

    # Assumes the model outputs one score per class for a batch of one; keep
    # the row for the single image in the batch.
    probabilities = prediction.tolist()[0]

    # Order must match the model's output units; update if the model changes.
    class_names = ["Blight disease on grape leaves", "Powdery mildew on grapes"]
    result = {
        "prediction": class_names[int(np.argmax(probabilities))],
        "probabilities": probabilities,
    }
    return result

# CORS settings: allow browser front-ends served from any origin to call
# this API.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# generally discouraged. It is low-risk only while the API relies on no
# cookies or auth (none is visible in this file); revisit with explicit
# origins or allow_credentials=False if that changes.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "OPTIONS"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

@app.options("/predict/")
async def options_predict():
    """Respond to OPTIONS requests on ``/predict/`` with its supported method.

    NOTE(review): the CORSMiddleware registered in this module likely answers
    CORS preflight requests itself, so this handler presumably only sees
    plain (non-preflight) OPTIONS requests — confirm.
    """
    supported = ["POST"]
    return {"methods": supported}