# anand_plant2 / server.py
# Last commit: 56dbdf2 "Create server.py" by ranchopanda0 (verified).
import io
import json
import logging

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification
logging.basicConfig(level=logging.INFO)


def _load_checkpoint(name):
    """Return ``(image_processor, classifier)`` loaded from checkpoint *name*.

    The image processor is loaded before the classification model.
    """
    image_processor = AutoImageProcessor.from_pretrained(name, use_fast=True)
    classifier = AutoModelForImageClassification.from_pretrained(name)
    return image_processor, classifier


# Load Model: preprocessing and classifier share one checkpoint identifier.
model_name = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
processor, model = _load_checkpoint(model_name)
# Load Treatments
def load_disease_treatments(path: str = "disease_treatments.json") -> dict:
    """Load the disease-label -> treatment-advice mapping from a JSON file.

    A relative *path* is resolved against the process's current working
    directory, not against this file's directory.

    Returns the parsed JSON object. If the file is missing, cannot be decoded
    as UTF-8 JSON, or does not contain a JSON object, a message is logged and
    a minimal built-in mapping is returned instead, so the server still starts.
    """
    fallback = {"Healthy": "No disease detected. Maintain proper plant care."}
    try:
        # Explicit UTF-8: the platform default encoding may differ (e.g. on
        # Windows) and would mis-decode non-ASCII treatment text.
        with open(path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        logging.warning("Treatment database file not found. Using default treatments.")
        return fallback
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; without this branch they would abort startup.
        logging.error(
            "Treatment database file %s could not be parsed (%s). Using default treatments.",
            path,
            e,
        )
        return fallback
    if not isinstance(data, dict):
        # predict() calls .get() on this mapping, so any other JSON value would
        # make every prediction return an error.
        logging.error(
            "Treatment database file %s does not contain a JSON object. Using default treatments.",
            path,
        )
        return fallback
    return data


disease_treatments = load_disease_treatments()
# FastAPI application instance: the /predict route is registered on it via the
# @app.post decorator, and it is the object uvicorn serves under __main__.
app = FastAPI()
def predict(image):
    """Classify a plant image and look up treatment advice for the result.

    Args:
        image: Input accepted by the module-level ``processor``; the
            ``/predict`` endpoint passes an RGB ``numpy`` array.

    Returns:
        ``{"Disease": <label>, "Treatment": <advice>}`` on success. The advice
        comes from ``disease_treatments``, or is a placeholder string when the
        label has no entry. On any exception, returns
        ``{"error": "Prediction failed: <message>"}`` instead of raising.
    """
    try:
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():  # inference only: no gradient tracking needed
            outputs = model(**inputs)
        # Highest-scoring class index along the last (class) axis.
        predicted_class_idx = outputs.logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_idx]
        treatment = disease_treatments.get(predicted_label, "No treatment info available.")
        return {"Disease": predicted_label, "Treatment": treatment}
    except Exception as e:
        # logging.exception logs at ERROR level and also records the traceback.
        logging.exception("Prediction failed: %s", e)
        return {"error": f"Prediction failed: {str(e)}"}
def _decode_image(contents):
    """Decode raw upload bytes into an RGB ``numpy`` array.

    Raises ``UnidentifiedImageError`` when PIL cannot identify the data as an
    image.
    """
    with Image.open(io.BytesIO(contents)) as image:
        return np.array(image.convert("RGB"))


@app.post("/predict")
async def api_predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return ``{"prediction": predict(...)}``.

    Responds with HTTP 400 when PIL cannot identify the upload as an image,
    and HTTP 500 for any other exception raised here. Failures inside
    ``predict()`` do not raise: they surface as
    ``{"prediction": {"error": ...}}`` with HTTP 200.
    """
    try:
        contents = await file.read()
        # Image decoding and model inference are CPU-bound. Run them in the
        # thread pool so they do not block the event loop for other requests.
        image_array = await run_in_threadpool(_decode_image, contents)
        prediction = await run_in_threadpool(predict, image_array)
        return {"prediction": prediction}
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail="Invalid image file.") from e
    except Exception as e:
        logging.exception("Unhandled error in /predict")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Bind to all IPv4 interfaces on port 7860 when run as a script.
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)