ZhengPeng7 commited on
Commit
a3743df
·
verified ·
1 Parent(s): 16e7ea5

Fix a bug in loading the input in tab_batch.

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -119,13 +119,13 @@ def predict(images, resolution, weights_file):
119
  # Apply the prediction mask to the original image
120
  image_pil = image_pil.resize(pred.shape[::-1])
121
  pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
122
- image_pred = (pred * np.array(image_pil)).astype(np.uint8)
123
 
124
  torch.cuda.empty_cache()
125
 
126
  if tab_is_batch:
127
  save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
128
- cv2.imwrite(save_file_path)
129
  save_paths.append(save_file_path)
130
 
131
  if tab_is_batch:
@@ -134,7 +134,7 @@ def predict(images, resolution, weights_file):
134
  for file in save_paths:
135
  zipf.write(file, os.path.basename(file))
136
 
137
- return image, image_pred
138
 
139
 
140
  examples = [[_] for _ in glob('examples/*')][:]
 
119
  # Apply the prediction mask to the original image
120
  image_pil = image_pil.resize(pred.shape[::-1])
121
  pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
122
+ image_masked = (pred * np.array(image_pil)).astype(np.uint8)
123
 
124
  torch.cuda.empty_cache()
125
 
126
  if tab_is_batch:
127
  save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
128
+ cv2.imwrite(save_file_path, image_masked)
129
  save_paths.append(save_file_path)
130
 
131
  if tab_is_batch:
 
134
  for file in save_paths:
135
  zipf.write(file, os.path.basename(file))
136
 
137
+ return image, image_masked
138
 
139
 
140
  examples = [[_] for _ in glob('examples/*')][:]