MnLgt commited on
Commit
d3966b5
·
1 Parent(s): 88d8a55

fixed body mask

Browse files
Files changed (1) hide show
  1. yolo/BodyMask.py +31 -13
yolo/BodyMask.py CHANGED
@@ -61,6 +61,9 @@ def combine_masks(masks: List[dict], labels: List[str], is_label=True) -> Image.
61
  """
62
  labels_set = set(labels) # Convert labels list to a set for O(1) lookups
63
 
 
 
 
64
  # Filter and convert mask images based on the specified labels
65
  mask_images = [
66
  mask["mask"].convert("L")
@@ -119,14 +122,25 @@ class BodyMask:
119
  self.body_box = self.get_body_box(
120
  remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
121
  )
122
- if overlay == "box":
123
- self.overlay = overlay_mask(
124
- self.image, self.body_box, opacity=0.9, color="red"
 
125
  )
 
 
 
 
 
 
 
 
126
  else:
127
- self.overlay = overlay_mask(
128
- self.image, self.body_mask, opacity=0.9, color="red"
129
- )
 
 
130
 
131
  def get_image(self, resize_to, resize_to_nearest_eight):
132
  image = load_image(self.image_path)
@@ -140,7 +154,17 @@ class BodyMask:
140
 
141
  def get_body_mask(self):
142
  body_mask = combine_masks(self.results, self.labels, self.is_label)
143
- return dilate_mask(body_mask, self.dilate_factor)
 
 
 
 
 
 
 
 
 
 
144
 
145
  def get_results(self):
146
  imgsz = max(self.image.size)
@@ -193,12 +217,6 @@ class BodyMask:
193
  combined_mask = ImageChops.lighter(combined_mask, mask)
194
  return combined_mask
195
 
196
- def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
197
- body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
198
- if remove_overlap:
199
- body_box = self.remove_overlap(body_box)
200
- return body_box
201
-
202
  def remove_overlap(self, body_box):
203
  """
204
  Remove mask regions that overlap with unwanted labels
 
61
  """
62
  labels_set = set(labels) # Convert labels list to a set for O(1) lookups
63
 
64
+ # Filter out any masks that do not have a label key
65
+ masks = [mask for mask in masks if "label" in mask]
66
+
67
  # Filter and convert mask images based on the specified labels
68
  mask_images = [
69
  mask["mask"].convert("L")
 
122
  self.body_box = self.get_body_box(
123
  remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
124
  )
125
+ if self.body_mask is not None:
126
+ self.box = get_bounding_box(self.body_mask)
127
+ self.body_box = self.get_body_box(
128
+ remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
129
  )
130
+ if overlay == "box":
131
+ self.overlay = overlay_mask(
132
+ self.image, self.body_box, opacity=0.9, color="red"
133
+ )
134
+ else:
135
+ self.overlay = overlay_mask(
136
+ self.image, self.body_mask, opacity=0.9, color="red"
137
+ )
138
  else:
139
+ self.box = None
140
+ self.body_box = None
141
+ self.overlay = (
142
+ self.image
143
+ ) # Just return the original image if no mask is found
144
 
145
  def get_image(self, resize_to, resize_to_nearest_eight):
146
  image = load_image(self.image_path)
 
154
 
155
  def get_body_mask(self):
156
  body_mask = combine_masks(self.results, self.labels, self.is_label)
157
+ if body_mask is not None:
158
+ return dilate_mask(body_mask, self.dilate_factor)
159
+ return None
160
+
161
+ def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
162
+ if self.body_mask is None:
163
+ return None
164
+ body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
165
+ if remove_overlap:
166
+ body_box = self.remove_overlap(body_box)
167
+ return body_box
168
 
169
  def get_results(self):
170
  imgsz = max(self.image.size)
 
217
  combined_mask = ImageChops.lighter(combined_mask, mask)
218
  return combined_mask
219
 
 
 
 
 
 
 
220
  def remove_overlap(self, body_box):
221
  """
222
  Remove mask regions that overlap with unwanted labels