jays009 commited on
Commit
7ef3e33
·
verified ·
1 Parent(s): 0c47ae4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -21
app.py CHANGED
@@ -44,21 +44,31 @@ transform = transforms.Compose([
44
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
45
  ])
46
 
47
- def predict_from_image(image):
48
  try:
49
- # Check if the input is a PIL Image type (Gradio automatically provides a PIL image)
50
- if not isinstance(image, Image.Image):
 
 
 
 
 
 
 
 
 
 
51
  return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
52
 
53
  # Apply transformations to the image
54
- image = transform(image).unsqueeze(0)
55
 
56
  # Move the image to the correct device
57
- image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
58
 
59
  # Make predictions
60
  with torch.no_grad():
61
- outputs = model(image)
62
  predicted_class = torch.argmax(outputs, dim=1).item()
63
 
64
  # Return the result based on the predicted class
@@ -73,23 +83,9 @@ def predict_from_image(image):
73
  return json.dumps({"error": f"Error processing image: {e}"})
74
 
75
 
76
- def predict_from_url(url):
77
- try:
78
- # Check if the URL is valid and try fetching the image
79
- response = requests.get(url)
80
- if response.status_code == 200:
81
- img = Image.open(BytesIO(response.content))
82
- # Call the predict function for the image
83
- return predict_from_image(img)
84
- else:
85
- return json.dumps({"error": "Unable to fetch image from the URL."})
86
- except Exception as e:
87
- return json.dumps({"error": f"Error fetching image from URL: {e}"})
88
-
89
-
90
  # Create the Gradio interface with both local file upload and URL input
91
  iface = gr.Interface(
92
- fn=predict_from_image,
93
  inputs=[gr.Image(type="pil", label="Upload an image or provide a local path"),
94
  gr.Textbox(label="Or enter image URL (if available)", placeholder="Enter a URL for the image")],
95
  outputs=gr.Textbox(label="Prediction Result"),
 
44
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
45
  ])
46
 
47
+ def predict(input_data):
48
  try:
49
+ # Check if the input is a URL or image (Gradio will pass either based on user input)
50
+ if isinstance(input_data, str): # If it's a string, assume it's a URL
51
+ response = requests.get(input_data)
52
+ if response.status_code == 200:
53
+ img = Image.open(BytesIO(response.content))
54
+ else:
55
+ return json.dumps({"error": "Unable to fetch image from the URL."})
56
+ else: # If it's not a string, assume it's an image file
57
+ img = input_data
58
+
59
+ # Check if the input is a PIL Image type
60
+ if not isinstance(img, Image.Image):
61
  return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
62
 
63
  # Apply transformations to the image
64
+ img = transform(img).unsqueeze(0)
65
 
66
  # Move the image to the correct device
67
+ img = img.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
68
 
69
  # Make predictions
70
  with torch.no_grad():
71
+ outputs = model(img)
72
  predicted_class = torch.argmax(outputs, dim=1).item()
73
 
74
  # Return the result based on the predicted class
 
83
  return json.dumps({"error": f"Error processing image: {e}"})
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  # Create the Gradio interface with both local file upload and URL input
87
  iface = gr.Interface(
88
+ fn=predict,
89
  inputs=[gr.Image(type="pil", label="Upload an image or provide a local path"),
90
  gr.Textbox(label="Or enter image URL (if available)", placeholder="Enter a URL for the image")],
91
  outputs=gr.Textbox(label="Prediction Result"),