Erick Garcia Espinosa commited on
Commit
b009213
·
1 Parent(s): 4d70acc

improvements

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -16,8 +16,11 @@ transform = transforms.Compose([
16
  ])
17
 
18
  # Function to load and preprocess an image
19
- def load_image(image_path):
20
- image = Image.open(image_path).convert('RGB')
 
 
 
21
  image = transform(image).unsqueeze(0) # Add batch dimension
22
  return image
23
 
@@ -34,12 +37,8 @@ def predict_image(image):
34
  if image is None:
35
  return "No image provided", None
36
 
37
- # Convert the image to PIL if it's not a filepath
38
- if isinstance(image, np.ndarray):
39
- image = Image.fromarray(image.astype('uint8'), 'RGB')
40
-
41
- # Load and transform the image from the file path
42
- img_tensor = transform(image).unsqueeze(0)
43
 
44
  # Perform the prediction
45
  with torch.no_grad():
@@ -72,7 +71,7 @@ def predict_image(image):
72
  iface = gr.Interface(
73
  fn=predict_image,
74
  inputs=[
75
- gr.Image(source="upload", type="pil", tool="editor", label="Upload an image or take a photo"),
76
  gr.Image(source="webcam", type="pil", tool="editor", label="Take a photo")
77
  ],
78
  outputs=[
 
16
  ])
17
 
18
  # Function to load and preprocess an image
19
+ def load_image(image):
20
+ if isinstance(image, str): # If image is a file path
21
+ image = Image.open(image).convert('RGB')
22
+ elif isinstance(image, np.ndarray): # If image is a numpy array
23
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
24
  image = transform(image).unsqueeze(0) # Add batch dimension
25
  return image
26
 
 
37
  if image is None:
38
  return "No image provided", None
39
 
40
+ # Load and transform the image
41
+ img_tensor = load_image(image)
 
 
 
 
42
 
43
  # Perform the prediction
44
  with torch.no_grad():
 
71
  iface = gr.Interface(
72
  fn=predict_image,
73
  inputs=[
74
+ gr.Image(type="pil", tool="editor", label="Upload an image or take a photo"),
75
  gr.Image(source="webcam", type="pil", tool="editor", label="Take a photo")
76
  ],
77
  outputs=[