"""FastAPI service that classifies plant-disease images and suggests treatments."""
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
import uvicorn
import io
import numpy as np
from PIL import Image, UnidentifiedImageError
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
import json
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load Model
model_name = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Used whenever the treatments file is missing or unusable.
_DEFAULT_TREATMENTS = {"Healthy": "No disease detected. Maintain proper plant care."}


def _load_treatments(path="disease_treatments.json"):
    """Load the label -> treatment mapping from a JSON file.

    Falls back to a copy of ``_DEFAULT_TREATMENTS`` (logging a warning) when
    the file is missing, is not valid UTF-8 JSON, or does not hold a JSON
    object, so a bad data file does not prevent the service from starting.
    """
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        logger.warning("Treatment database file not found. Using default treatments.")
        return dict(_DEFAULT_TREATMENTS)
    except ValueError as err:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        logger.warning(
            "Treatment database file is invalid (%s). Using default treatments.", err
        )
        return dict(_DEFAULT_TREATMENTS)
    if not isinstance(data, dict):
        # Lookups below use .get(label), so anything but a JSON object is unusable.
        logger.warning("Treatment database is not a JSON object. Using default treatments.")
        return dict(_DEFAULT_TREATMENTS)
    return data


# Load Treatments
disease_treatments = _load_treatments()

app = FastAPI()


class _InvalidImageError(ValueError):
    """Raised when uploaded bytes cannot be decoded into an image."""


def _decode_image(contents):
    """Decode raw upload bytes into an RGB ``numpy`` array.

    Raises:
        _InvalidImageError: if Pillow cannot identify or fully decode the data,
            or rejects it as a possible decompression bomb.
    """
    try:
        with Image.open(io.BytesIO(contents)) as img:
            # Image.open is lazy; convert() forces the full decode, so truncated
            # or corrupt data fails here (as OSError) instead of escaping as a 500.
            rgb = img.convert("RGB")
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as err:
        raise _InvalidImageError(str(err)) from err
    return np.array(rgb)


def _infer(image):
    """Classify ``image`` and attach the matching treatment text.

    Returns:
        ``{"Disease": <label>, "Treatment": <text>}``.

    Raises:
        Any exception from the processor or model; callers decide how to report it.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # .item() requires a single element, i.e. exactly one image per call.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_idx]
    treatment = disease_treatments.get(predicted_label, "No treatment info available.")
    return {"Disease": predicted_label, "Treatment": treatment}


def predict(image):
    """Classify ``image``, returning an error dict instead of raising.

    Returns ``{"Disease": ..., "Treatment": ...}`` on success and
    ``{"error": "Prediction failed: ..."}`` on any failure. The HTTP endpoint
    calls ``_infer`` directly so that failures produce a 5xx status.
    """
    try:
        return _infer(image)
    except Exception as e:
        logger.exception("Prediction failed: %s", e)
        return {"error": f"Prediction failed: {str(e)}"}


@app.post("/predict")
async def api_predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return ``{"prediction": {...}}``.

    Responds with 400 when the upload cannot be decoded as an image and with
    500 when reading the upload or running inference fails.
    """
    try:
        contents = await file.read()
        # Decoding and inference are synchronous and CPU-bound; run them in the
        # threadpool so they do not block the event loop for other requests.
        image_array = await run_in_threadpool(_decode_image, contents)
        prediction = await run_in_threadpool(_infer, image_array)
    except _InvalidImageError:
        raise HTTPException(status_code=400, detail="Invalid image file.")
    except Exception as e:
        logger.exception("Request to /predict failed")
        raise HTTPException(status_code=500, detail=str(e))
    return {"prediction": prediction}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)