ranchopanda0 commited on
Commit
4bb0527
·
verified ·
1 Parent(s): be7dd52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -77
app.py CHANGED
@@ -1,111 +1,80 @@
1
  import gradio as gr
2
  from transformers import AutoImageProcessor, AutoModelForImageClassification
3
- from PIL import Image, UnidentifiedImageError
4
  import torch
5
  import numpy as np
6
- from fastapi import FastAPI, UploadFile, File, HTTPException
7
- import uvicorn
8
- import io
9
  import json
10
- import threading
11
- import os
12
  import logging
13
- import signal
14
- import sys
15
 
16
  # Configure Logging
17
- logging.basicConfig(level=logging.INFO)
18
 
19
  # Load Model & Processor
20
  model_name = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
21
- processor = AutoImageProcessor.from_pretrained(model_name, use_fast=False) # Avoid fast mode issue
22
- model = AutoModelForImageClassification.from_pretrained(model_name)
 
 
 
 
 
23
 
24
- # FastAPI Setup
25
- app = FastAPI()
26
 
27
- # Load Disease Treatment Database Dynamically
28
- disease_treatments = {}
29
- try:
30
- with open("disease_treatments.json", "r") as file:
31
- disease_treatments = json.load(file)
32
- except FileNotFoundError:
33
- logging.warning("Treatment database file not found. Using default treatments.")
34
- disease_treatments = {
35
- "Powdery Mildew": "Use fungicides like sulfur or neem oil. Improve air circulation.",
36
- "Leaf Blight": "Apply copper-based fungicides and ensure proper plant spacing.",
37
- "Rust": "Use resistant varieties and apply organic sulfur fungicide.",
38
- "Healthy": "No disease detected! Keep maintaining proper watering and soil health.",
39
  }
 
 
 
 
 
 
 
 
40
 
41
- # Input Validation for Image Size
42
  def validate_image(image):
43
  if image.size[0] < 64 or image.size[1] < 64:
44
- raise ValueError("Image is too small. Please upload an image with dimensions at least 64x64.")
45
  return image
46
 
47
- # Define Prediction Function
48
  def predict(image):
49
  try:
50
- image = validate_image(Image.fromarray(np.uint8(image)).convert("RGB"))
 
 
51
  inputs = processor(images=image, return_tensors="pt")
52
  with torch.no_grad():
53
  outputs = model(**inputs)
54
  logits = outputs.logits
55
  predicted_class_idx = logits.argmax(-1).item()
56
  predicted_label = model.config.id2label[predicted_class_idx]
57
- treatment = disease_treatments.get(predicted_label, "No treatment information available.")
58
- return {"Disease": predicted_label, "Treatment": treatment}
 
 
59
  except Exception as e:
60
  logging.error(f"Prediction failed: {str(e)}")
61
- return {"error": f"Prediction failed: {str(e)}"}
62
 
63
  # Gradio Interface
64
  iface = gr.Interface(
65
- fn=lambda img: (lambda pred: f"Predicted Disease: {pred.get('Disease', 'Error')}\nRecommended Treatment: {pred.get('Treatment', 'N/A')}")(predict(img)),
66
- inputs=gr.Image(type="numpy"),
67
- outputs="text",
68
- title="🌿 Plant Disease Detector",
69
- description="Upload an image or take a photo to identify plant diseases.",
 
70
  )
71
 
72
- # FastAPI Endpoint
73
- @app.post("/predict")
74
- async def api_predict(file: UploadFile = File(...)):
75
- logging.info(f"Received file: {file.filename}")
76
- try:
77
- contents = await file.read()
78
- image = Image.open(io.BytesIO(contents)).convert("RGB")
79
- image_array = np.array(image)
80
- prediction = predict(image_array)
81
- logging.info(f"Prediction successful: {prediction}")
82
- return {"prediction": prediction}
83
- except UnidentifiedImageError:
84
- logging.error("Invalid image file uploaded.")
85
- raise HTTPException(status_code=400, detail="Invalid image file. Please upload a valid image.")
86
- except ValueError as ve:
87
- logging.error(f"ValueError during prediction: {str(ve)}")
88
- raise HTTPException(status_code=400, detail=f"Invalid input: {str(ve)}")
89
- except Exception as e:
90
- logging.error(f"Internal server error: {str(e)}")
91
- raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
92
-
93
- # Configurable Ports
94
- FASTAPI_PORT = int(os.getenv("FASTAPI_PORT", 7860))
95
- GRADIO_PORT = int(os.getenv("GRADIO_PORT", 7862)) # Changed from 7861 to 7862
96
-
97
- # Graceful Shutdown
98
- def shutdown(signum, frame):
99
- logging.info("Shutting down servers...")
100
- sys.exit(0)
101
-
102
- signal.signal(signal.SIGINT, shutdown)
103
- signal.signal(signal.SIGTERM, shutdown)
104
-
105
- # Run FastAPI & Gradio Together
106
- def run_fastapi():
107
- uvicorn.run(app, host="0.0.0.0", port=FASTAPI_PORT)
108
-
109
- if __name__ == "__main__":
110
- threading.Thread(target=run_fastapi, daemon=True).start()
111
- iface.launch(server_name="0.0.0.0", server_port=GRADIO_PORT, share=True)
 
1
  import gradio as gr
2
  from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
  import torch
5
  import numpy as np
 
 
 
6
  import json
 
 
7
  import logging
8
+ import os
 
9
 
10
  # Configure Logging
11
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
12
 
13
  # Load Model & Processor
14
  model_name = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
15
+ try:
16
+ processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
17
+ model = AutoModelForImageClassification.from_pretrained(model_name)
18
+ logging.info("✅ Model and processor loaded successfully.")
19
+ except Exception as e:
20
+ logging.error(f"❌ Failed to load model: {str(e)}")
21
+ raise RuntimeError("Failed to load the model. Please check the logs for details.")
22
 
23
+ # Load or Create Disease Treatment Database
24
+ disease_treatments_file = "disease_treatments.json"
25
 
26
+ if not os.path.exists(disease_treatments_file):
27
+ logging.warning("⚠️ Treatment database file not found. Creating a default one.")
28
+ default_treatments = {
29
+ "Powdery Mildew": "Use fungicides like sulfur or neem oil.",
30
+ "Leaf Blight": "Apply copper-based fungicides.",
31
+ "Rust": "Use resistant varieties.",
32
+ "Healthy": "No disease detected!",
 
 
 
 
 
33
  }
34
+ with open(disease_treatments_file, "w") as file:
35
+ json.dump(default_treatments, file)
36
+ logging.info("✅ Created default 'disease_treatments.json' file.")
37
+
38
+ # Load Treatments
39
+ with open(disease_treatments_file, "r") as file:
40
+ disease_treatments = json.load(file)
41
+ logging.info("✅ Treatment database loaded successfully.")
42
 
43
+ # Image Validation
44
  def validate_image(image):
45
  if image.size[0] < 64 or image.size[1] < 64:
46
+ raise ValueError("⚠️ Image is too small. Please upload an image of at least 64x64 pixels.")
47
  return image
48
 
49
+ # Prediction Function
50
  def predict(image):
51
  try:
52
+ image = Image.fromarray(np.uint8(image)).convert("RGB")
53
+ validate_image(image)
54
+
55
  inputs = processor(images=image, return_tensors="pt")
56
  with torch.no_grad():
57
  outputs = model(**inputs)
58
  logits = outputs.logits
59
  predicted_class_idx = logits.argmax(-1).item()
60
  predicted_label = model.config.id2label[predicted_class_idx]
61
+
62
+ treatment = disease_treatments.get(predicted_label, "No treatment information available for this disease.")
63
+ return f"Predicted Disease: {predicted_label}\nTreatment: {treatment}"
64
+
65
  except Exception as e:
66
  logging.error(f"Prediction failed: {str(e)}")
67
+ return f" Prediction failed: {str(e)}"
68
 
69
  # Gradio Interface
70
  iface = gr.Interface(
71
+ fn=predict,
72
+ inputs=gr.Image(type="numpy", label="Upload or capture plant image"),
73
+ outputs=gr.Textbox(label="Result"),
74
+ title="Plant Disease Detector",
75
+ description="Upload a plant leaf image to detect diseases and get treatment suggestions.",
76
+ allow_flagging="never",
77
  )
78
 
79
+ # Launch Gradio App
80
+ iface.launch(share=True) # No fixed port to avoid conflicts