ranchopanda0 commited on
Commit
be7dd52
·
verified ·
1 Parent(s): 9410e4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -3
app.py CHANGED
@@ -1,13 +1,111 @@
1
  import gradio as gr
2
- from server import predict
 
 
 
 
 
 
 
 
 
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  iface = gr.Interface(
5
  fn=lambda img: (lambda pred: f"Predicted Disease: {pred.get('Disease', 'Error')}\nRecommended Treatment: {pred.get('Treatment', 'N/A')}")(predict(img)),
6
  inputs=gr.Image(type="numpy"),
7
  outputs="text",
8
  title="🌿 Plant Disease Detector",
9
- description="Upload an image to identify plant diseases.",
10
  )
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  if __name__ == "__main__":
13
- iface.launch(server_name="0.0.0.0", server_port=7861, share=True)
 
 
1
  import gradio as gr
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image, UnidentifiedImageError
4
+ import torch
5
+ import numpy as np
6
+ from fastapi import FastAPI, UploadFile, File, HTTPException
7
+ import uvicorn
8
+ import io
9
+ import json
10
+ import threading
11
+ import os
12
+ import logging
13
+ import signal
14
+ import sys
15
 
16
+ # Configure Logging
17
+ logging.basicConfig(level=logging.INFO)
18
+
19
+ # Load Model & Processor
20
+ model_name = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
21
+ processor = AutoImageProcessor.from_pretrained(model_name, use_fast=False) # Avoid fast mode issue
22
+ model = AutoModelForImageClassification.from_pretrained(model_name)
23
+
24
+ # FastAPI Setup
25
+ app = FastAPI()
26
+
27
+ # Load Disease Treatment Database Dynamically
28
+ disease_treatments = {}
29
+ try:
30
+ with open("disease_treatments.json", "r") as file:
31
+ disease_treatments = json.load(file)
32
+ except FileNotFoundError:
33
+ logging.warning("Treatment database file not found. Using default treatments.")
34
+ disease_treatments = {
35
+ "Powdery Mildew": "Use fungicides like sulfur or neem oil. Improve air circulation.",
36
+ "Leaf Blight": "Apply copper-based fungicides and ensure proper plant spacing.",
37
+ "Rust": "Use resistant varieties and apply organic sulfur fungicide.",
38
+ "Healthy": "No disease detected! Keep maintaining proper watering and soil health.",
39
+ }
40
+
41
+ # Input Validation for Image Size
42
+ def validate_image(image):
43
+ if image.size[0] < 64 or image.size[1] < 64:
44
+ raise ValueError("Image is too small. Please upload an image with dimensions at least 64x64.")
45
+ return image
46
+
47
+ # Define Prediction Function
48
+ def predict(image):
49
+ try:
50
+ image = validate_image(Image.fromarray(np.uint8(image)).convert("RGB"))
51
+ inputs = processor(images=image, return_tensors="pt")
52
+ with torch.no_grad():
53
+ outputs = model(**inputs)
54
+ logits = outputs.logits
55
+ predicted_class_idx = logits.argmax(-1).item()
56
+ predicted_label = model.config.id2label[predicted_class_idx]
57
+ treatment = disease_treatments.get(predicted_label, "No treatment information available.")
58
+ return {"Disease": predicted_label, "Treatment": treatment}
59
+ except Exception as e:
60
+ logging.error(f"Prediction failed: {str(e)}")
61
+ return {"error": f"Prediction failed: {str(e)}"}
62
+
63
+ # Gradio Interface
64
  iface = gr.Interface(
65
  fn=lambda img: (lambda pred: f"Predicted Disease: {pred.get('Disease', 'Error')}\nRecommended Treatment: {pred.get('Treatment', 'N/A')}")(predict(img)),
66
  inputs=gr.Image(type="numpy"),
67
  outputs="text",
68
  title="🌿 Plant Disease Detector",
69
+ description="Upload an image or take a photo to identify plant diseases.",
70
  )
71
 
72
+ # FastAPI Endpoint
73
+ @app.post("/predict")
74
+ async def api_predict(file: UploadFile = File(...)):
75
+ logging.info(f"Received file: {file.filename}")
76
+ try:
77
+ contents = await file.read()
78
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
79
+ image_array = np.array(image)
80
+ prediction = predict(image_array)
81
+ logging.info(f"Prediction successful: {prediction}")
82
+ return {"prediction": prediction}
83
+ except UnidentifiedImageError:
84
+ logging.error("Invalid image file uploaded.")
85
+ raise HTTPException(status_code=400, detail="Invalid image file. Please upload a valid image.")
86
+ except ValueError as ve:
87
+ logging.error(f"ValueError during prediction: {str(ve)}")
88
+ raise HTTPException(status_code=400, detail=f"Invalid input: {str(ve)}")
89
+ except Exception as e:
90
+ logging.error(f"Internal server error: {str(e)}")
91
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
92
+
93
+ # Configurable Ports
94
+ FASTAPI_PORT = int(os.getenv("FASTAPI_PORT", 7860))
95
+ GRADIO_PORT = int(os.getenv("GRADIO_PORT", 7862)) # Changed from 7861 to 7862
96
+
97
+ # Graceful Shutdown
98
+ def shutdown(signum, frame):
99
+ logging.info("Shutting down servers...")
100
+ sys.exit(0)
101
+
102
+ signal.signal(signal.SIGINT, shutdown)
103
+ signal.signal(signal.SIGTERM, shutdown)
104
+
105
+ # Run FastAPI & Gradio Together
106
+ def run_fastapi():
107
+ uvicorn.run(app, host="0.0.0.0", port=FASTAPI_PORT)
108
+
109
  if __name__ == "__main__":
110
+ threading.Thread(target=run_fastapi, daemon=True).start()
111
+ iface.launch(server_name="0.0.0.0", server_port=GRADIO_PORT, share=True)