Spaces:
Running
Running
import os
from io import BytesIO

import numpy as np
import requests
import tensorflow as tf
import tensorflow_addons as tfa
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from tensorflow.keras.preprocessing.image import img_to_array
# The ASGI application object; it is what ``uvicorn.run`` serves when this
# module is executed as a script.
app = FastAPI()

# Extra objects Keras needs in order to deserialize the saved model: the
# registered name "Addons>CohenKappa" is mapped to the tensorflow-addons metric
# class. (Presumably the model was compiled with this metric — confirm.)
custom_objects = {"Addons>CohenKappa": tfa.metrics.CohenKappa}
# Location of the saved Keras model, relative to the working directory.
model_path = "model.h5"

# Maps each index of the model's output vector to its class name; the name at
# position i in this list is the label for output index i.
class_labels = dict(enumerate([
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    "Background_without_leaves",
    "Blueberry___healthy",
    "Cherry___Powdery_mildew",
    "Cherry___healthy",
    "Corn___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn___Common_rust_",
    "Corn___Northern_Leaf_Blight",
    "Corn___healthy",
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Grape___healthy",
    "Orange___Haunglongbing_(Citrus_greening)",
    "Peach___Bacterial_spot",
    "Peach___healthy",
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    "Strawberry___Leaf_scorch",
    "Strawberry___healthy",
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
]))
# Load the model at import time if the file is present. When it is missing,
# bind ``model`` to None so the name always exists and callers can detect the
# absent model explicitly instead of failing with a NameError.
if os.path.exists(model_path):
    model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
    print("Model loaded successfully.")
else:
    model = None
    print(f"Model file not found at {model_path}. Please upload the model.")
def preprocess_image(image_data, img_size=224):
    """Decode raw image bytes into a normalised single-image batch.

    Args:
        image_data: Encoded image file contents in any format PIL can open.
        img_size: Side length, in pixels, that the image is resized to.

    Returns:
        A float array of shape ``(1, img_size, img_size, 3)`` (Keras' default
        channels-last layout) with pixel values scaled to ``[0, 1]``.

    Raises:
        PIL.UnidentifiedImageError: If the bytes are not a decodable image.
    """
    with Image.open(BytesIO(image_data)) as img:
        # Force three colour channels. RGBA PNGs, greyscale (L) and palette (P)
        # images would otherwise produce 4 or 1 channels and fail inside
        # model.predict. NOTE(review): assumes the model expects RGB input.
        rgb_img = img.convert("RGB").resize((img_size, img_size))
    img_array = img_to_array(rgb_img) / 255.0
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)
def predict_image(image_data, timeout=10.0):
    """Classify a leaf image and fetch supplementary data for the predicted class.

    The image is run through the loaded ``model``. The top-scoring output index
    is mapped to a name via ``class_labels`` ("Unknown" if unmapped), and that
    name is used to query the companion backend service.

    Args:
        image_data: Encoded image bytes, as accepted by ``preprocess_image``.
        timeout: Seconds to wait on the external service. Without a timeout
            ``requests`` may block forever on an unresponsive server.

    Returns:
        The decoded JSON body from the external service when it answers with
        HTTP 200. Otherwise, a dict with a single ``"error"`` key.

    Raises:
        RuntimeError: If no model is loaded.
        Any exception raised while decoding the image or running the model.
    """
    # ``model`` is only bound at import time when the model file exists, so
    # look it up defensively instead of risking a NameError.
    loaded_model = globals().get("model")
    if loaded_model is None:
        raise RuntimeError(f"Model is not loaded; expected a model file at {model_path}.")

    preprocessed_image = preprocess_image(image_data)
    predictions = loaded_model.predict(preprocessed_image)
    class_idx = int(np.argmax(predictions, axis=1)[0])  # plain int for dict lookup
    class_label = class_labels.get(class_idx, "Unknown")

    # Fetch additional data for the predicted class from the external API.
    try:
        response = requests.get(
            f"https://navpan2-sarva-ai-back.hf.space/kotlinback/{class_label}",
            timeout=timeout,
        )
        external_data = response.json() if response.status_code == 200 else {"error": "Failed to fetch external data"}
    except (requests.RequestException, ValueError) as e:
        # Network failures, timeouts and non-JSON bodies degrade to an error
        # payload rather than failing the whole prediction.
        external_data = {"error": str(e)}
    return external_data
# Route for health check.
# NOTE(review): the "/health" path is assumed — confirm against clients.
@app.get("/health")
async def api_health_check():
    """Report that the service process is up and able to answer requests."""
    return JSONResponse(content={"status": "Service is running"})
# Route for prediction using an uploaded image.
# NOTE(review): the "/predict" path is assumed — confirm against clients.
@app.post("/predict")
async def api_predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction payload as JSON.

    ``predict_image`` runs TensorFlow inference and a blocking HTTP request,
    so it is executed in a worker thread to keep the event loop responsive.

    Returns:
        ``{"prediction": ...}`` on success. ``{"error": message}`` with HTTP
        400 if the upload is not a decodable image, or with HTTP 500 for any
        other failure.
    """
    try:
        image_data = await file.read()
        prediction = await run_in_threadpool(predict_image, image_data)
    except UnidentifiedImageError as e:
        # The client sent something PIL cannot decode: a client error.
        return JSONResponse(status_code=400, content={"error": str(e)})
    except Exception as e:  # API boundary: report the failure instead of crashing
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(content={"prediction": prediction})
# Script entry point: when executed directly (not imported), serve ``app``
# with uvicorn on all network interfaces at port 7860.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )