Spaces:
Running
on
Zero
Running
on
Zero
Fix the bug in identifying the type of image input.
Browse files
app.py
CHANGED
@@ -102,15 +102,16 @@ def predict(images, resolution, weights_file):
|
|
102 |
resolution = [1024, 1024]
|
103 |
print('Invalid resolution input. Automatically changed to 1024x1024.')
|
104 |
|
105 |
-
print('type(images):', type(images))
|
106 |
if isinstance(images, list):
|
|
|
|
|
107 |
save_dir = 'preds-BiRefNet'
|
108 |
if not os.path.exists(save_dir):
|
109 |
os.makedirs(save_dir)
|
|
|
110 |
else:
|
111 |
-
# For tab_batch
|
112 |
-
save_paths = []
|
113 |
images = [images]
|
|
|
114 |
|
115 |
for idx_image, image_src in enumerate(images):
|
116 |
if isinstance(image_src, str):
|
@@ -119,38 +120,38 @@ def predict(images, resolution, weights_file):
|
|
119 |
image = np.array(Image.open(image_data))
|
120 |
else:
|
121 |
image = image_src
|
122 |
-
|
123 |
image_shape = image.shape[:2]
|
124 |
image_pil = array_to_pil_image(image, tuple(resolution))
|
125 |
-
|
126 |
# Preprocess the image
|
127 |
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
128 |
image_proc = image_preprocessor.proc(image_pil)
|
129 |
image_proc = image_proc.unsqueeze(0)
|
130 |
-
|
131 |
# Perform the prediction
|
132 |
with torch.no_grad():
|
133 |
scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
|
134 |
-
|
135 |
if device == 'cuda':
|
136 |
scaled_pred_tensor = scaled_pred_tensor.cpu()
|
137 |
-
|
138 |
# Resize the prediction to match the original image shape
|
139 |
pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
|
140 |
-
|
141 |
# Apply the prediction mask to the original image
|
142 |
image_pil = image_pil.resize(pred.shape[::-1])
|
143 |
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
144 |
image_pred = (pred * np.array(image_pil)).astype(np.uint8)
|
145 |
-
|
146 |
torch.cuda.empty_cache()
|
147 |
-
|
148 |
-
if
|
149 |
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
150 |
cv2.imwrite(save_file_path)
|
151 |
save_paths.append(save_file_path)
|
152 |
|
153 |
-
if
|
154 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
155 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
156 |
for file in save_paths:
|
|
|
102 |
resolution = [1024, 1024]
|
103 |
print('Invalid resolution input. Automatically changed to 1024x1024.')
|
104 |
|
|
|
105 |
if isinstance(images, list):
|
106 |
+
# For tab_batch
|
107 |
+
save_paths = []
|
108 |
save_dir = 'preds-BiRefNet'
|
109 |
if not os.path.exists(save_dir):
|
110 |
os.makedirs(save_dir)
|
111 |
+
tab_is_batch = True
|
112 |
else:
|
|
|
|
|
113 |
images = [images]
|
114 |
+
tab_is_batch = False
|
115 |
|
116 |
for idx_image, image_src in enumerate(images):
|
117 |
if isinstance(image_src, str):
|
|
|
120 |
image = np.array(Image.open(image_data))
|
121 |
else:
|
122 |
image = image_src
|
123 |
+
|
124 |
image_shape = image.shape[:2]
|
125 |
image_pil = array_to_pil_image(image, tuple(resolution))
|
126 |
+
|
127 |
# Preprocess the image
|
128 |
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
129 |
image_proc = image_preprocessor.proc(image_pil)
|
130 |
image_proc = image_proc.unsqueeze(0)
|
131 |
+
|
132 |
# Perform the prediction
|
133 |
with torch.no_grad():
|
134 |
scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
|
135 |
+
|
136 |
if device == 'cuda':
|
137 |
scaled_pred_tensor = scaled_pred_tensor.cpu()
|
138 |
+
|
139 |
# Resize the prediction to match the original image shape
|
140 |
pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
|
141 |
+
|
142 |
# Apply the prediction mask to the original image
|
143 |
image_pil = image_pil.resize(pred.shape[::-1])
|
144 |
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
145 |
image_pred = (pred * np.array(image_pil)).astype(np.uint8)
|
146 |
+
|
147 |
torch.cuda.empty_cache()
|
148 |
+
|
149 |
+
if tab_is_batch:
|
150 |
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
151 |
cv2.imwrite(save_file_path)
|
152 |
save_paths.append(save_file_path)
|
153 |
|
154 |
+
if tab_is_batch:
|
155 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
156 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
157 |
for file in save_paths:
|