jays009 commited on
Commit
4869d07
·
verified ·
1 Parent(s): 5479ea5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -11
app.py CHANGED
@@ -6,8 +6,6 @@ from torchvision import models, transforms
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
9
- import os
10
- import base64
11
  from io import BytesIO
12
 
13
  # Define the number of classes
@@ -46,17 +44,14 @@ transform = transforms.Compose([
46
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
47
  ])
48
 
49
- def predict(image):
50
  try:
51
- print(f"Received image input: {image}")
52
-
53
  # Check if the input is a PIL Image type (Gradio automatically provides a PIL image)
54
  if not isinstance(image, Image.Image):
55
  return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
56
 
57
  # Apply transformations to the image
58
  image = transform(image).unsqueeze(0)
59
- print(f"Transformed image tensor: {image.shape}")
60
 
61
  # Move the image to the correct device
62
  image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
@@ -65,7 +60,6 @@ def predict(image):
65
  with torch.no_grad():
66
  outputs = model(image)
67
  predicted_class = torch.argmax(outputs, dim=1).item()
68
- print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")
69
 
70
  # Return the result based on the predicted class
71
  if predicted_class == 0:
@@ -76,18 +70,43 @@ def predict(image):
76
  return json.dumps({"error": "Unexpected class prediction."})
77
 
78
  except Exception as e:
79
- print(f"Error processing image: {e}")
80
  return json.dumps({"error": f"Error processing image: {e}"})
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # Create the Gradio interface
83
  iface = gr.Interface(
84
- fn=predict,
85
- inputs=gr.Image(type="pil", label="Upload an image or provide a URL or local path"),
86
  outputs=gr.Textbox(label="Prediction Result"),
87
  live=True,
88
  title="Maize Anomaly Detection",
89
  description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
90
  )
91
 
92
- # Launch the Gradio interface
 
 
 
 
 
 
 
 
 
 
93
  iface.launch(share=True, show_error=True)
 
 
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
 
 
9
  from io import BytesIO
10
 
11
  # Define the number of classes
 
44
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
45
  ])
46
 
47
+ def predict_from_image(image):
48
  try:
 
 
49
  # Check if the input is a PIL Image type (Gradio automatically provides a PIL image)
50
  if not isinstance(image, Image.Image):
51
  return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
52
 
53
  # Apply transformations to the image
54
  image = transform(image).unsqueeze(0)
 
55
 
56
  # Move the image to the correct device
57
  image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
60
  with torch.no_grad():
61
  outputs = model(image)
62
  predicted_class = torch.argmax(outputs, dim=1).item()
 
63
 
64
  # Return the result based on the predicted class
65
  if predicted_class == 0:
 
70
  return json.dumps({"error": "Unexpected class prediction."})
71
 
72
  except Exception as e:
 
73
  return json.dumps({"error": f"Error processing image: {e}"})
74
 
75
+
76
+ def predict_from_url(url):
77
+ try:
78
+ # Check if the URL is valid and try fetching the image
79
+ response = requests.get(url)
80
+ if response.status_code == 200:
81
+ img = Image.open(BytesIO(response.content))
82
+ # Call the predict function for the image
83
+ return predict_from_image(img)
84
+ else:
85
+ return json.dumps({"error": "Unable to fetch image from the URL."})
86
+ except Exception as e:
87
+ return json.dumps({"error": f"Error fetching image from URL: {e}"})
88
+
89
+
90
  # Create the Gradio interface
91
  iface = gr.Interface(
92
+ fn=predict_from_image,
93
+ inputs=gr.Image(type="pil", label="Upload an image or provide a local path"),
94
  outputs=gr.Textbox(label="Prediction Result"),
95
  live=True,
96
  title="Maize Anomaly Detection",
97
  description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
98
  )
99
 
100
+ # Add another function for URL input
101
+ url_iface = gr.Interface(
102
+ fn=predict_from_url,
103
+ inputs=gr.Textbox(label="Enter image URL"),
104
+ outputs=gr.Textbox(label="Prediction Result from URL"),
105
+ live=True,
106
+ title="Maize Anomaly Detection from URL",
107
+ description="Provide an image URL to detect anomalies like disease or pest infestation."
108
+ )
109
+
110
+ # Launch the Gradio interface
111
  iface.launch(share=True, show_error=True)
112
+ url_iface.launch(share=True, show_error=True)