MnLgt commited on
Commit
739d5db
·
1 Parent(s): d3966b5

"updated app Bodymask and utils"

Browse files
Files changed (3) hide show
  1. app.py +2 -0
  2. yolo/BodyMask.py +49 -105
  3. yolo/utils.py +17 -10
app.py CHANGED
@@ -86,6 +86,8 @@ def display_image_with_masks(image, results, cols=4):
86
 
87
  def perform_segmentation(input_image):
88
  bm = BodyMask(input_image, model_id=model_id, resize_to=640)
 
 
89
  results = bm.results
90
  buf = display_image_with_masks(input_image, results)
91
 
 
86
 
87
  def perform_segmentation(input_image):
88
  bm = BodyMask(input_image, model_id=model_id, resize_to=640)
89
+ if bm.body_mask is None:
90
+ return input_image # Return the original image if no mask is found
91
  results = bm.results
92
  buf = display_image_with_masks(input_image, results)
93
 
yolo/BodyMask.py CHANGED
@@ -89,7 +89,6 @@ body_labels = ["hair", "face", "arm", "hand", "leg", "foot", "outfit"]
89
 
90
 
91
  class BodyMask:
92
-
93
  def __init__(
94
  self,
95
  image_path,
@@ -118,29 +117,11 @@ class BodyMask:
118
  self.results = self.get_results()
119
  self.dilate_factor = dilate_factor
120
  self.body_mask = self.get_body_mask()
121
- self.box = get_bounding_box(self.body_mask)
122
  self.body_box = self.get_body_box(
123
  remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
124
  )
125
- if self.body_mask is not None:
126
- self.box = get_bounding_box(self.body_mask)
127
- self.body_box = self.get_body_box(
128
- remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
129
- )
130
- if overlay == "box":
131
- self.overlay = overlay_mask(
132
- self.image, self.body_box, opacity=0.9, color="red"
133
- )
134
- else:
135
- self.overlay = overlay_mask(
136
- self.image, self.body_mask, opacity=0.9, color="red"
137
- )
138
- else:
139
- self.box = None
140
- self.body_box = None
141
- self.overlay = (
142
- self.image
143
- ) # Just return the original image if no mask is found
144
 
145
  def get_image(self, resize_to, resize_to_nearest_eight):
146
  image = load_image(self.image_path)
@@ -148,61 +129,59 @@ class BodyMask:
148
  image = resize_preserve_aspect_ratio(image, resize_to)
149
  if resize_to_nearest_eight:
150
  image = resize_image_to_nearest_eight(image)
151
- else:
152
- image = image
153
  return image
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  def get_body_mask(self):
156
  body_mask = combine_masks(self.results, self.labels, self.is_label)
157
  if body_mask is not None:
158
  return dilate_mask(body_mask, self.dilate_factor)
159
  return None
160
 
161
- def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
162
  if self.body_mask is None:
163
  return None
 
 
 
 
 
164
  body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
165
- if remove_overlap:
166
  body_box = self.remove_overlap(body_box)
167
  return body_box
168
 
169
- def get_results(self):
170
- imgsz = max(self.image.size)
171
- results = self.model(
172
- self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
173
- )[0]
174
- self.masks, self.boxes, self.scores, self.phrases = unload(
175
- results, self.model_labels
176
- )
177
- results = format_results(
178
- self.masks,
179
- self.boxes,
180
- self.scores,
181
- self.phrases,
182
- self.model_labels,
183
- person_masks_only=False,
184
- )
185
-
186
- # filter out lower score results
187
- masks_to_filter = ["hair"]
188
- results = filter_highest_score(results, ["hair", "face", "phone"])
189
- return results
190
-
191
- def display_results(self):
192
- if len(self.masks) < 4:
193
- cols = len(self.masks)
194
- else:
195
- cols = 4
196
- display_image_with_masks(self.image, self.results, cols=cols)
197
 
198
- def get_mask(self, mask_label):
199
- assert mask_label in self.phrases, "Mask label not found in results"
200
- return [f for f in self.results if f.get("label") == mask_label]
 
 
 
 
 
 
 
201
 
202
  def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
203
- """
204
- Combine the masks included in the labels list or all of the masks not in the list
205
- """
206
  if not is_label:
207
  mask_labels = [
208
  phrase for phrase in self.phrases if phrase not in mask_labels
@@ -217,50 +196,15 @@ class BodyMask:
217
  combined_mask = ImageChops.lighter(combined_mask, mask)
218
  return combined_mask
219
 
220
- def remove_overlap(self, body_box):
221
- """
222
- Remove mask regions that overlap with unwanted labels
223
- """
224
- # convert mask to numpy array
225
- box_array = np.array(body_box)
226
-
227
- # combine the masks for those labels
228
- mask = self.combine_masks(mask_labels=self.labels, is_label=True)
229
-
230
- # convert mask to numpy array
231
- mask_array = np.array(mask)
232
-
233
- # where the mask array is white set the box array to black
234
- box_array[mask_array == 255] = 0
235
-
236
- # convert the box array to an image
237
- mask_image = Image.fromarray(box_array)
238
- return mask_image
239
-
240
-
241
- if __name__ == "__main__":
242
- url = "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
243
- image_name = url.split("/")[-1]
244
- labels = ["face", "hair", "phone", "hand"]
245
- image = load_image(url)
246
- image_size = image.size
247
- # Get the original size of the image
248
- original_size = image.size
249
-
250
- # Create body mask
251
- body_mask = BodyMask(
252
- image,
253
- overlay="box",
254
- labels=labels,
255
- widen_box=50,
256
- elongate_box=10,
257
- dilate_factor=0,
258
- resize_to=640,
259
- is_label=False,
260
- remove_overlap=True,
261
- verbose=False,
262
- )
263
 
264
- # Resize the image back to the original size
265
- image = body_mask.image.resize(original_size)
266
- body_mask.body_box.save(image_name)
 
 
 
89
 
90
 
91
  class BodyMask:
 
92
  def __init__(
93
  self,
94
  image_path,
 
117
  self.results = self.get_results()
118
  self.dilate_factor = dilate_factor
119
  self.body_mask = self.get_body_mask()
120
+ self.box = self.get_bounding_box()
121
  self.body_box = self.get_body_box(
122
  remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
123
  )
124
+ self.overlay = self.create_overlay(overlay)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  def get_image(self, resize_to, resize_to_nearest_eight):
127
  image = load_image(self.image_path)
 
129
  image = resize_preserve_aspect_ratio(image, resize_to)
130
  if resize_to_nearest_eight:
131
  image = resize_image_to_nearest_eight(image)
 
 
132
  return image
133
 
134
+ def get_results(self):
135
+ imgsz = max(self.image.size)
136
+ results = self.model(
137
+ self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
138
+ )[0]
139
+ masks, boxes, scores, phrases = unload(results, self.model_labels)
140
+ results = format_results(
141
+ masks, boxes, scores, phrases, self.model_labels, person_masks_only=False
142
+ )
143
+ masks_to_filter = ["hair"]
144
+ results = filter_highest_score(results, ["hair", "face", "phone"])
145
+ return results
146
+
147
  def get_body_mask(self):
148
  body_mask = combine_masks(self.results, self.labels, self.is_label)
149
  if body_mask is not None:
150
  return dilate_mask(body_mask, self.dilate_factor)
151
  return None
152
 
153
+ def get_bounding_box(self):
154
  if self.body_mask is None:
155
  return None
156
+ return get_bounding_box(self.body_mask)
157
+
158
+ def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
159
+ if self.body_mask is None or self.box is None:
160
+ return None
161
  body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
162
+ if remove_overlap and body_box is not None:
163
  body_box = self.remove_overlap(body_box)
164
  return body_box
165
 
166
+ def create_overlay(self, overlay_type):
167
+ if self.body_box is not None and overlay_type == "box":
168
+ return overlay_mask(self.image, self.body_box, opacity=0.9, color="red")
169
+ elif self.body_mask is not None:
170
+ return overlay_mask(self.image, self.body_mask, opacity=0.9, color="red")
171
+ return self.image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ def remove_overlap(self, body_box):
174
+ if body_box is None:
175
+ return None
176
+ box_array = np.array(body_box)
177
+ mask = self.combine_masks(mask_labels=self.labels, is_label=True)
178
+ if mask is None:
179
+ return body_box
180
+ mask_array = np.array(mask)
181
+ box_array[mask_array == 255] = 0
182
+ return Image.fromarray(box_array)
183
 
184
  def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
 
 
 
185
  if not is_label:
186
  mask_labels = [
187
  phrase for phrase in self.phrases if phrase not in mask_labels
 
196
  combined_mask = ImageChops.lighter(combined_mask, mask)
197
  return combined_mask
198
 
199
+ def display_results(self):
200
+ if not self.results:
201
+ print("No results to display.")
202
+ return
203
+ cols = min(len(self.results), 4)
204
+ display_image_with_masks(self.image, self.results, cols=cols)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ def get_mask(self, mask_label):
207
+ if mask_label not in self.phrases:
208
+ print(f"Mask label '{mask_label}' not found in results.")
209
+ return None
210
+ return [f for f in self.results if f.get("label") == mask_label]
yolo/utils.py CHANGED
@@ -178,16 +178,23 @@ def display_image_with_masks(image, results, cols=4, return_images=False):
178
 
179
 
180
  def get_bounding_box(mask):
181
- """
182
- Given a segmentation mask, return the bounding box for the mask object.
183
- """
184
- # Find indices where the mask is non-zero
185
- coords = np.argwhere(mask)
186
- # Get the minimum and maximum x and y coordinates
187
- x_min, y_min = np.min(coords, axis=0)
188
- x_max, y_max = np.max(coords, axis=0)
189
- # Return the bounding box coordinates
190
- return (y_min, x_min, y_max, x_max)
 
 
 
 
 
 
 
191
 
192
 
193
  def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):
 
178
 
179
 
180
  def get_bounding_box(mask):
181
+ if mask is None or not isinstance(mask, np.ndarray):
182
+ return None
183
+
184
+ # Check if the mask is empty
185
+ if mask.size == 0 or np.all(mask == 0):
186
+ return None
187
+
188
+ # Find the bounding box
189
+ rows = np.any(mask, axis=1)
190
+ cols = np.any(mask, axis=0)
191
+ if not np.any(rows) or not np.any(cols):
192
+ return None
193
+
194
+ rmin, rmax = np.where(rows)[0][[0, -1]]
195
+ cmin, cmax = np.where(cols)[0][[0, -1]]
196
+
197
+ return (int(cmin), int(rmin), int(cmax), int(rmax))
198
 
199
 
200
  def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):