ranchopanda0 commited on
Commit
56dbdf2
·
verified ·
1 Parent(s): 12263fd

Create server.py

Browse files
Files changed (1) hide show
  1. server.py +56 -0
server.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ import uvicorn
3
+ import io
4
+ import numpy as np
5
+ from PIL import Image, UnidentifiedImageError
6
+ import torch
7
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
+ import json
9
+ import logging
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+
13
+ # Load Model
14
+ model_name = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
15
+ processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
16
+ model = AutoModelForImageClassification.from_pretrained(model_name)
17
+
18
+ # Load Treatments
19
+ disease_treatments = {}
20
+ try:
21
+ with open("disease_treatments.json", "r") as file:
22
+ disease_treatments = json.load(file)
23
+ except FileNotFoundError:
24
+ logging.warning("Treatment database file not found. Using default treatments.")
25
+ disease_treatments = {"Healthy": "No disease detected. Maintain proper plant care."}
26
+
27
+ app = FastAPI()
28
+
29
+ def predict(image):
30
+ try:
31
+ inputs = processor(images=image, return_tensors="pt")
32
+ with torch.no_grad():
33
+ outputs = model(**inputs)
34
+ predicted_class_idx = outputs.logits.argmax(-1).item()
35
+ predicted_label = model.config.id2label[predicted_class_idx]
36
+ treatment = disease_treatments.get(predicted_label, "No treatment info available.")
37
+ return {"Disease": predicted_label, "Treatment": treatment}
38
+ except Exception as e:
39
+ logging.error(f"Prediction failed: {str(e)}")
40
+ return {"error": f"Prediction failed: {str(e)}"}
41
+
42
+ @app.post("/predict")
43
+ async def api_predict(file: UploadFile = File(...)):
44
+ try:
45
+ contents = await file.read()
46
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
47
+ image_array = np.array(image)
48
+ prediction = predict(image_array)
49
+ return {"prediction": prediction}
50
+ except UnidentifiedImageError:
51
+ raise HTTPException(status_code=400, detail="Invalid image file.")
52
+ except Exception as e:
53
+ raise HTTPException(status_code=500, detail=str(e))
54
+
55
+ if __name__ == "__main__":
56
+ uvicorn.run(app, host="0.0.0.0", port=7860)