# Spaces: Sleeping  (appears to be a hosting-page status banner captured with this file; not part of the program)
import io
import json
import logging

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Route INFO-and-above records through the root logger.
logging.basicConfig(level=logging.INFO)

# Load Model
# Checkpoint identifier handed to both transformers ``from_pretrained`` calls,
# so the classifier and its image pre-processor always come from the same place.
model_name = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
# Load Treatments
def load_disease_treatments(path="disease_treatments.json"):
    """Return the disease-label -> treatment mapping stored as JSON at *path*.

    The service must be able to start without a usable treatment database, so
    every load problem is logged as a warning and answered with a minimal
    default mapping instead of an exception.

    Args:
        path: Location of the JSON file; its root value must be an object.

    Returns:
        The parsed mapping, or ``{"Healthy": ...}`` when the file is missing,
        is not valid UTF-8 JSON, or does not contain a JSON object.
    """
    # Built per call so callers never share (and mutate) one default dict.
    default_treatments = {
        "Healthy": "No disease detected. Maintain proper plant care."
    }
    try:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as file:
            loaded = json.load(file)
    except FileNotFoundError:
        logging.warning("Treatment database file not found. Using default treatments.")
        return default_treatments
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        logging.warning(
            "Treatment database file is not valid JSON (%s). Using default treatments.",
            exc,
        )
        return default_treatments

    # Lookups are done with ``.get``, so anything but a JSON object is unusable.
    if not isinstance(loaded, dict):
        logging.warning(
            "Treatment database file does not contain a JSON object. Using default treatments."
        )
        return default_treatments
    return loaded


disease_treatments = load_disease_treatments()
# ASGI application object; handed to uvicorn when this file is run as a script.
app: FastAPI = FastAPI()
def predict(image) -> dict:
    """Classify *image* and look up the matching treatment advice.

    Args:
        image: Image in whatever form the module-level ``processor`` accepts
            (``api_predict`` passes an RGB NumPy array).

    Returns:
        ``{"Disease": label, "Treatment": advice}`` on success; ``advice`` is
        ``"No treatment info available."`` when the label has no entry in
        ``disease_treatments``. On any failure the exception is logged with
        its traceback and ``{"error": "Prediction failed: ..."}`` is returned
        instead of raising.
    """
    try:
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        # Highest-scoring logit along the last axis is the predicted class.
        predicted_class_idx = outputs.logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_idx]
        treatment = disease_treatments.get(predicted_label, "No treatment info available.")
        return {"Disease": predicted_label, "Treatment": treatment}
    except Exception as e:
        # Deliberate catch-all: callers receive an error payload, never an
        # exception. logging.exception keeps the traceback in the log.
        logging.exception("Prediction failed: %s", e)
        return {"error": f"Prediction failed: {e}"}
# Register the handler on the app: without a route decorator the server
# started at the bottom of this file would expose no endpoint at all.
@app.post("/predict")
async def api_predict(file: UploadFile = File(...)):
    """Classify an uploaded plant image.

    Args:
        file: Multipart upload holding the encoded image bytes.

    Returns:
        ``{"prediction": <dict returned by predict()>}``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image
            (unrecognised, truncated or corrupt data); 500 for any other
            failure, with the exception text as the detail.
    """
    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (UnidentifiedImageError, OSError) as exc:
            # UnidentifiedImageError covers unknown formats; plain OSError is
            # what PIL raises for truncated/corrupt data while decoding.
            # Both are the client's fault, not a server error.
            raise HTTPException(status_code=400, detail="Invalid image file.") from exc
        image_array = np.array(image)
        prediction = predict(image_array)
        return {"prediction": prediction}
    except HTTPException:
        # Already the response we want; do not re-wrap it as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Script entry point: serve ``app`` with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, port=7860, host="0.0.0.0")