jays009 commited on
Commit
fc29cbf
·
verified ·
1 Parent(s): 01a4ed7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -55
app.py CHANGED
@@ -45,70 +45,53 @@ transform = transforms.Compose([
45
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
46
  ])
47
 
48
- def predict(image):
49
  try:
50
- print(f"Received image input: {image}")
51
-
52
- # Check if the input contains a base64-encoded string
53
- if isinstance(image, dict) and image.get("data"):
54
- try:
55
- image_data = base64.b64decode(image["data"])
56
- image = Image.open(BytesIO(image_data))
57
- print(f"Decoded base64 image: {image}")
58
- except Exception as e:
59
- print(f"Error decoding base64 image: {e}")
60
- return f"Error decoding base64 image: {e}"
61
-
62
- # Check if the input is a URL
63
- elif isinstance(image, str) and image.startswith("http"):
64
- try:
65
- response = requests.get(image)
66
  image = Image.open(BytesIO(response.content))
67
- print(f"Fetched image from URL: {image}")
68
- except Exception as e:
69
- print(f"Error fetching image from URL: {e}")
70
- return f"Error fetching image from URL: {e}"
71
-
72
- # Check if the input is a local file path
73
- elif isinstance(image, str) and os.path.isfile(image):
74
- try:
75
- image = Image.open(image)
76
- print(f"Loaded image from local path: {image}")
77
- except Exception as e:
78
- print(f"Error loading image from local path: {e}")
79
- return f"Error loading image from local path: {e}"
80
-
81
- # Validate that the image is correctly loaded
82
- if not isinstance(image, Image.Image):
83
- print("Invalid image format received.")
84
- return "Invalid image format received."
85
-
86
- # Apply transformations
87
- image = transform(image).unsqueeze(0)
88
- print(f"Transformed image tensor: {image.shape}")
89
-
90
- image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
91
-
92
- with torch.no_grad():
93
- outputs = model(image)
94
- predicted_class = torch.argmax(outputs, dim=1).item()
95
- print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")
96
-
97
- if predicted_class == 0:
98
- return "The photo you've sent is of fall army worm with problem ID 126."
99
- elif predicted_class == 1:
100
- return "The photo you've sent is of a healthy maize image."
101
  else:
102
- return "Unexpected class prediction."
103
  except Exception as e:
104
  print(f"Error processing image: {e}")
105
  return f"Error processing image: {e}"
106
 
107
  # Create the Gradio interface
108
  iface = gr.Interface(
109
- fn=predict,
110
- inputs=gr.Image(type="pil", label="Upload an image or provide a URL or local path"), # Input: Image, URL, or Local Path
111
- outputs=gr.Textbox(label="Prediction Result"), # Output: Predicted class
112
  live=True,
113
  title="Maize Anomaly Detection",
114
  description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
 
45
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
46
  ])
47
 
48
+ def process_image(image_input):
49
  try:
50
+ # Process the image input (URL, local file, or base64)
51
+ if isinstance(image_input, dict):
52
+ # Check if the input contains a URL
53
+ if image_input.get("url"):
54
+ image_url = image_input["url"]
55
+ response = requests.get(image_url)
 
 
 
 
 
 
 
 
 
 
56
  image = Image.open(BytesIO(response.content))
57
+ # Check if the input contains a file path
58
+ elif image_input.get("path"):
59
+ image_path = image_input["path"]
60
+ image = Image.open(image_path)
61
+ # Handle base64 if it's included
62
+ elif image_input.get("data"):
63
+ image_data = base64.b64decode(image_input["data"])
64
+ image = Image.open(BytesIO(image_data))
65
+ else:
66
+ return "Invalid input data format. Please provide a URL or path."
67
+
68
+ # Apply transformations
69
+ image = transform(image).unsqueeze(0)
70
+ image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
71
+
72
+ # Make the prediction
73
+ with torch.no_grad():
74
+ outputs = model(image)
75
+ predicted_class = torch.argmax(outputs, dim=1).item()
76
+
77
+ # Return prediction result
78
+ if predicted_class == 0:
79
+ return "The photo you've sent is of fall army worm with problem ID 126."
80
+ elif predicted_class == 1:
81
+ return "The photo you've sent is of a healthy maize image."
82
+ else:
83
+ return "Unexpected class prediction."
 
 
 
 
 
 
 
84
  else:
85
+ return "Invalid input. Please provide a dictionary with 'url' or 'path'."
86
  except Exception as e:
87
  print(f"Error processing image: {e}")
88
  return f"Error processing image: {e}"
89
 
90
  # Create the Gradio interface
91
  iface = gr.Interface(
92
+ fn=process_image,
93
+ inputs=gr.JSON(label="Upload an image (URL or Local Path)"), # Input: JSON to handle URL or path
94
+ outputs=gr.Textbox(label="Prediction Result"), # Output: Prediction result
95
  live=True,
96
  title="Maize Anomaly Detection",
97
  description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."