ZhengPeng7 commited on
Commit
5023a18
·
verified ·
1 Parent(s): 7e47111

Fix the bug in identifying the type of image input.

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -102,15 +102,16 @@ def predict(images, resolution, weights_file):
102
  resolution = [1024, 1024]
103
  print('Invalid resolution input. Automatically changed to 1024x1024.')
104
 
105
- print('type(images):', type(images))
106
  if isinstance(images, list):
 
 
107
  save_dir = 'preds-BiRefNet'
108
  if not os.path.exists(save_dir):
109
  os.makedirs(save_dir)
 
110
  else:
111
- # For tab_batch
112
- save_paths = []
113
  images = [images]
 
114
 
115
  for idx_image, image_src in enumerate(images):
116
  if isinstance(image_src, str):
@@ -119,38 +120,38 @@ def predict(images, resolution, weights_file):
119
  image = np.array(Image.open(image_data))
120
  else:
121
  image = image_src
122
-
123
  image_shape = image.shape[:2]
124
  image_pil = array_to_pil_image(image, tuple(resolution))
125
-
126
  # Preprocess the image
127
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
128
  image_proc = image_preprocessor.proc(image_pil)
129
  image_proc = image_proc.unsqueeze(0)
130
-
131
  # Perform the prediction
132
  with torch.no_grad():
133
  scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
134
-
135
  if device == 'cuda':
136
  scaled_pred_tensor = scaled_pred_tensor.cpu()
137
-
138
  # Resize the prediction to match the original image shape
139
  pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
140
-
141
  # Apply the prediction mask to the original image
142
  image_pil = image_pil.resize(pred.shape[::-1])
143
  pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
144
  image_pred = (pred * np.array(image_pil)).astype(np.uint8)
145
-
146
  torch.cuda.empty_cache()
147
-
148
- if isinstance(images, list):
149
  save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
150
  cv2.imwrite(save_file_path)
151
  save_paths.append(save_file_path)
152
 
153
- if len(save_paths) > 1:
154
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
155
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
156
  for file in save_paths:
 
102
  resolution = [1024, 1024]
103
  print('Invalid resolution input. Automatically changed to 1024x1024.')
104
 
 
105
  if isinstance(images, list):
106
+ # For tab_batch
107
+ save_paths = []
108
  save_dir = 'preds-BiRefNet'
109
  if not os.path.exists(save_dir):
110
  os.makedirs(save_dir)
111
+ tab_is_batch = True
112
  else:
 
 
113
  images = [images]
114
+ tab_is_batch = False
115
 
116
  for idx_image, image_src in enumerate(images):
117
  if isinstance(image_src, str):
 
120
  image = np.array(Image.open(image_data))
121
  else:
122
  image = image_src
123
+
124
  image_shape = image.shape[:2]
125
  image_pil = array_to_pil_image(image, tuple(resolution))
126
+
127
  # Preprocess the image
128
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
129
  image_proc = image_preprocessor.proc(image_pil)
130
  image_proc = image_proc.unsqueeze(0)
131
+
132
  # Perform the prediction
133
  with torch.no_grad():
134
  scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
135
+
136
  if device == 'cuda':
137
  scaled_pred_tensor = scaled_pred_tensor.cpu()
138
+
139
  # Resize the prediction to match the original image shape
140
  pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
141
+
142
  # Apply the prediction mask to the original image
143
  image_pil = image_pil.resize(pred.shape[::-1])
144
  pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
145
  image_pred = (pred * np.array(image_pil)).astype(np.uint8)
146
+
147
  torch.cuda.empty_cache()
148
+
149
+ if tab_is_batch:
150
  save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
151
  cv2.imwrite(save_file_path)
152
  save_paths.append(save_file_path)
153
 
154
+ if tab_is_batch:
155
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
156
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
157
  for file in save_paths: