jays009 commited on
Commit
52fd9c2
·
verified ·
1 Parent(s): fb8a03b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -37
app.py CHANGED
@@ -6,9 +6,9 @@ from torchvision import models, transforms
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
 
9
  import base64
10
  from io import BytesIO
11
- import os
12
 
13
  # Define the number of classes
14
  num_classes = 2
@@ -46,46 +46,22 @@ transform = transforms.Compose([
46
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
47
  ])
48
 
49
- def predict(data):
50
  try:
51
- # Check if the data is a list and not empty
52
- if not isinstance(data, list) or len(data) == 0:
53
- return json.dumps({"error": "Input data should be a non-empty list."})
54
-
55
- # Extract the image path
56
- image_input = data[0].get('path', None)
57
- if not image_input:
58
- return json.dumps({"error": "No image path provided."})
59
 
60
- print(f"Received image input: {image_input}")
61
-
62
- # Handle URLs
63
- if isinstance(image_input, str) and (image_input.startswith("http://") or image_input.startswith("https://")):
64
- try:
65
- response = requests.get(image_input)
66
- response.raise_for_status() # Check for HTTP errors
67
- image = Image.open(BytesIO(response.content))
68
- print(f"Fetched image from URL: {image}")
69
- except Exception as e:
70
- print(f"Error fetching image from URL: {e}")
71
- return json.dumps({"error": f"Error fetching image from URL: {e}"})
72
-
73
- # Check if the image path is a valid local path
74
- elif isinstance(image_input, str) and os.path.exists(image_input):
75
- try:
76
- image = Image.open(image_input)
77
- print(f"Loaded image from local path: {image}")
78
- except Exception as e:
79
- return json.dumps({"error": f"Error loading image from local path: {e}"})
80
-
81
- else:
82
- return json.dumps({"error": "Invalid image path. Ensure it's a valid URL or local path."})
83
 
84
- # Apply the transformations and make prediction
85
  image = transform(image).unsqueeze(0)
86
  print(f"Transformed image tensor: {image.shape}")
 
 
87
  image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
88
 
 
89
  with torch.no_grad():
90
  outputs = model(image)
91
  predicted_class = torch.argmax(outputs, dim=1).item()
@@ -98,7 +74,7 @@ def predict(data):
98
  return json.dumps({"result": "The photo you've sent is of a healthy maize image."})
99
  else:
100
  return json.dumps({"error": "Unexpected class prediction."})
101
-
102
  except Exception as e:
103
  print(f"Error processing image: {e}")
104
  return json.dumps({"error": f"Error processing image: {e}"})
@@ -106,7 +82,7 @@ def predict(data):
106
  # Create the Gradio interface
107
  iface = gr.Interface(
108
  fn=predict,
109
- inputs=gr.JSON(label="Input JSON"),
110
  outputs=gr.Textbox(label="Prediction Result"),
111
  live=True,
112
  title="Maize Anomaly Detection",
@@ -115,4 +91,3 @@ iface = gr.Interface(
115
 
116
  # Launch the Gradio interface
117
  iface.launch(share=True, show_error=True)
118
-
 
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
9
+ import os
10
  import base64
11
  from io import BytesIO
 
12
 
13
  # Define the number of classes
14
  num_classes = 2
 
46
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
47
  ])
48
 
49
+ def predict(image):
50
  try:
51
+ print(f"Received image input: {image}")
 
 
 
 
 
 
 
52
 
53
+ # Check if the input is a PIL Image type (Gradio already provides a PIL image)
54
+ if not isinstance(image, Image.Image):
55
+ return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ # Apply transformations to the image
58
  image = transform(image).unsqueeze(0)
59
  print(f"Transformed image tensor: {image.shape}")
60
+
61
+ # Move the image to the correct device
62
  image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
63
 
64
+ # Make predictions
65
  with torch.no_grad():
66
  outputs = model(image)
67
  predicted_class = torch.argmax(outputs, dim=1).item()
 
74
  return json.dumps({"result": "The photo you've sent is of a healthy maize image."})
75
  else:
76
  return json.dumps({"error": "Unexpected class prediction."})
77
+
78
  except Exception as e:
79
  print(f"Error processing image: {e}")
80
  return json.dumps({"error": f"Error processing image: {e}"})
 
82
  # Create the Gradio interface
83
  iface = gr.Interface(
84
  fn=predict,
85
+ inputs=gr.Image(type="pil", label="Upload an image or provide a URL or local path"),
86
  outputs=gr.Textbox(label="Prediction Result"),
87
  live=True,
88
  title="Maize Anomaly Detection",
 
91
 
92
  # Launch the Gradio interface
93
  iface.launch(share=True, show_error=True)