gokaygokay commited on
Commit
b920029
1 Parent(s): c594645

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -7,8 +7,18 @@ import os
7
 
8
  def get_image(img, mask=False):
9
  if mask:
 
 
 
 
10
  return np.where(img > 127, 1, 0)
11
- return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype('double') / 255.0
 
 
 
 
 
 
12
 
13
  def neighbours(i, j, max_i, max_j):
14
  pairs = []
@@ -172,6 +182,10 @@ def blend_images(bg_img, obj_img, mask_img, blend_method):
172
  obj_img = get_image(obj_img)
173
  mask_img = get_image(mask_img, mask=True)
174
 
 
 
 
 
175
  # Resize mask to match object image size
176
  mask_img = cv2.resize(mask_img, (obj_img.shape[1], obj_img.shape[0]))
177
 
 
7
 
8
  def get_image(img, mask=False):
9
  if mask:
10
+ if isinstance(img, str):
11
+ img = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
12
+ elif img.ndim == 3:
13
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
14
  return np.where(img > 127, 1, 0)
15
+ else:
16
+ if isinstance(img, str):
17
+ img = cv2.imread(img)
18
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
19
+ elif img.ndim == 2:
20
+ img = np.stack((img,)*3, axis=-1)
21
+ return img.astype('double') / 255.0
22
 
23
  def neighbours(i, j, max_i, max_j):
24
  pairs = []
 
182
  obj_img = get_image(obj_img)
183
  mask_img = get_image(mask_img, mask=True)
184
 
185
+ # Ensure mask is 2D
186
+ if mask_img.ndim == 3:
187
+ mask_img = mask_img[:,:,0] # Take the first channel if it's 3D
188
+
189
  # Resize mask to match object image size
190
  mask_img = cv2.resize(mask_img, (obj_img.shape[1], obj_img.shape[0]))
191