jays009 commited on
Commit
0c6fbcb
·
verified ·
1 Parent(s): 2ba8106

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -34
app.py CHANGED
@@ -6,6 +6,8 @@ from torchvision import models, transforms
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
 
 
9
  from io import BytesIO
10
 
11
  # Define the number of classes
@@ -44,46 +46,26 @@ transform = transforms.Compose([
44
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
45
  ])
46
 
47
- def predict(input_data):
48
  try:
49
- print(f"Input data received: {input_data}, Type: {type(input_data)}")
50
-
51
- # Check if the input is a URL or image
52
- if isinstance(input_data, str): # If it's a string, assume it's a URL
53
- try:
54
- response = requests.get(input_data)
55
- response.raise_for_status() # Raise error for HTTP issues
56
- img = Image.open(BytesIO(response.content))
57
- print("Image fetched successfully from URL.")
58
- except Exception as e:
59
- print(f"Error fetching image from URL: {e}")
60
- return json.dumps({"error": f"Failed to fetch image from URL: {e}"})
61
- else: # If it's not a string, assume it's an image file
62
- img = input_data
63
 
64
- # Validate the image
65
- if not isinstance(img, Image.Image):
66
- print("Invalid image format received.")
67
  return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
68
- else:
69
- print(f"Image successfully loaded: {img}")
70
 
71
  # Apply transformations to the image
72
- img = transform(img).unsqueeze(0)
73
- print(f"Transformed image tensor shape: {img.shape}")
74
-
75
- # Ensure model is loaded
76
- if model is None:
77
- return json.dumps({"error": "Model not loaded. Ensure the model file is available and correctly loaded."})
78
 
79
  # Move the image to the correct device
80
- img = img.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
81
 
82
  # Make predictions
83
  with torch.no_grad():
84
- outputs = model(img)
85
  predicted_class = torch.argmax(outputs, dim=1).item()
86
- print(f"Model prediction outputs: {outputs}, Predicted class: {predicted_class}")
87
 
88
  # Return the result based on the predicted class
89
  if predicted_class == 0:
@@ -97,17 +79,15 @@ def predict(input_data):
97
  print(f"Error processing image: {e}")
98
  return json.dumps({"error": f"Error processing image: {e}"})
99
 
100
-
101
- # Create the Gradio interface with both local file upload and URL input
102
  iface = gr.Interface(
103
  fn=predict,
104
- inputs=[gr.Image(type="pil", label="Upload an image or provide a local path"),
105
- gr.Textbox(label="Or enter image URL (if available)", placeholder="Enter a URL for the image")],
106
  outputs=gr.Textbox(label="Prediction Result"),
107
  live=True,
108
  title="Maize Anomaly Detection",
109
  description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
110
  )
111
 
112
- # Launch the Gradio interface
113
  iface.launch(share=True, show_error=True)
 
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
8
  import requests
9
+ import os
10
+ import base64
11
  from io import BytesIO
12
 
13
  # Define the number of classes
 
46
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
47
  ])
48
 
49
+ def predict(image):
50
  try:
51
+ print(f"Received image input: {image}")
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # Check if the input is a PIL Image type (Gradio automatically provides a PIL image)
54
+ if not isinstance(image, Image.Image):
 
55
  return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
 
 
56
 
57
  # Apply transformations to the image
58
+ image = transform(image).unsqueeze(0)
59
+ print(f"Transformed image tensor: {image.shape}")
 
 
 
 
60
 
61
  # Move the image to the correct device
62
+ image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
63
 
64
  # Make predictions
65
  with torch.no_grad():
66
+ outputs = model(image)
67
  predicted_class = torch.argmax(outputs, dim=1).item()
68
+ print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")
69
 
70
  # Return the result based on the predicted class
71
  if predicted_class == 0:
 
79
  print(f"Error processing image: {e}")
80
  return json.dumps({"error": f"Error processing image: {e}"})
81
 
82
+ # Create the Gradio interface
 
83
  iface = gr.Interface(
84
  fn=predict,
85
+ inputs=gr.Image(type="pil", label="Upload an image or provide a URL or local path"),
 
86
  outputs=gr.Textbox(label="Prediction Result"),
87
  live=True,
88
  title="Maize Anomaly Detection",
89
  description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
90
  )
91
 
92
+ # Launch the Gradio interface
93
  iface.launch(share=True, show_error=True)