Saad0KH commited on
Commit
61a0586
Β·
verified Β·
1 Parent(s): 8414811

Update SegCloth.py

Browse files
Files changed (1) hide show
  1. SegCloth.py +17 -19
SegCloth.py CHANGED
@@ -1,42 +1,40 @@
 
1
  from PIL import Image
2
  import numpy as np
3
  from io import BytesIO
4
  import io
5
  import base64
6
 
7
-
8
-
9
  # Initialize segmentation pipeline
10
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
11
 
12
-
13
  def encode_image_to_base64(image):
14
  buffered = BytesIO()
15
  image.save(buffered, format="PNG")
16
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
17
 
18
- def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Scarf"]):
19
  # Segment image
20
  segments = segmenter(img)
21
 
22
- # Convert image to RGBA
23
- img = img.convert("RGBA")
24
-
25
  # Create list of masks
26
- result_images = []
27
  for s in segments:
28
- if s['label'] in clothes:
29
- final_mask = Image.fromarray(current_mask)
30
- resized_img.putalpha(final_mask)
31
 
 
32
 
33
-
34
-
35
- # Convert the final image to base64
36
- imageBase64 = encode_image_to_base64(resized_img)
37
- result_images.append((s['label'], imageBase64))
38
-
39
- return result_images
 
40
 
41
 
42
-
 
 
 
1
+ from transformers import pipeline
2
  from PIL import Image
3
  import numpy as np
4
  from io import BytesIO
5
  import io
6
  import base64
7
 
 
 
8
  # Initialize segmentation pipeline
9
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
10
 
 
11
  def encode_image_to_base64(image):
12
  buffered = BytesIO()
13
  image.save(buffered, format="PNG")
14
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
15
 
16
+ def segment_clothing(img, clothes= ["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
17
  # Segment image
18
  segments = segmenter(img)
19
 
 
 
 
20
  # Create list of masks
21
+ mask_list = []
22
  for s in segments:
23
+ if(s['label'] in clothes):
24
+ mask_list.append(s['mask'])
 
25
 
26
+ result_images = []
27
 
28
+ # Paste all masks on top of eachother
29
+ final_mask = np.array(mask_list[0])
30
+ for mask in mask_list:
31
+ current_mask = np.array(mask)
32
+ final_mask_bis = Image.fromarray(current_mask)
33
+ img.putalpha(final_mask_bis)
34
+ imageBase64 = encode_image_to_base64(img)
35
+ result_images.append(('clothing_type', imageBase64))
36
 
37
 
38
+
39
+
40
+ return result_images