Saad0KH committed on
Commit
0bac47f
·
verified ·
1 Parent(s): b33f918

Update SegCloth.py

Browse files
Files changed (1) hide show
  1. SegCloth.py +14 -24
SegCloth.py CHANGED
@@ -1,52 +1,42 @@
1
- from transformers import pipeline
2
  from PIL import Image
3
  import numpy as np
4
  from io import BytesIO
 
5
  import base64
6
- import torch
7
- from torchvision.transforms.functional import to_pil_image
8
 
9
- # Initialize segmentation and super-resolution pipelines
 
 
10
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
11
- super_res = pipeline("image-super-resolution", model="CompVis/stable-diffusion-v-1-4")
12
 
13
def encode_image_to_base64(image):
    """Serialize a PIL image to a base64 string of its PNG bytes."""
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode('utf-8')
17
 
18
def segment_and_enhance_clothing(
    img,
    clothes=("Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt",
             "Left-shoe", "Right-shoe", "Scarf"),
):
    """Segment clothing items from *img*, upscale each, and return them
    base64-encoded.

    Parameters
    ----------
    img : PIL.Image.Image
        Input photo. ``convert`` returns a new image, so the caller's
        object is not mutated.
    clothes : sequence of str
        Segment labels to keep. The default is a tuple — the original
        used a mutable list default, a classic Python pitfall.

    Returns
    -------
    list[tuple[str, str]]
        ``(label, base64_string)`` pairs, one per matching segment.
    """
    # Module-level segmentation pipeline produces dicts with 'label'/'mask'.
    segments = segmenter(img)
    img = img.convert("RGBA")

    result_images = []
    for s in segments:
        if s['label'] not in clothes:
            continue
        # numpy shape is (height, width); [::-1] gives PIL's (width, height).
        current_mask = np.array(s['mask'])
        mask_size = current_mask.shape[::-1]
        # Resize the original image to match the mask, then use the mask
        # as its alpha channel so only the garment stays opaque.
        resized_img = img.resize(mask_size)
        resized_img.putalpha(Image.fromarray(current_mask))
        # NOTE(review): `super_res` is a transformers pipeline; pipelines
        # usually return a list/dict, not a PIL image — confirm before
        # passing the result straight to the encoder.
        enhanced_img = super_res(resized_img)
        result_images.append((s['label'], encode_image_to_base64(enhanced_img)))
    return result_images
 
47
 
48
- # Example usage
49
- # img = Image.open('your_image_path_here.png')
50
- # result_images = segment_and_enhance_clothing(img)
51
- # for clothing_type, image_base64 in result_images:
52
- # print(clothing_type, image_base64)
 
 
1
# NOTE(review): this commit deleted `from transformers import pipeline`
# but still calls `pipeline(...)` below, which raises NameError at import
# time — the import is restored here so the module loads.
from transformers import pipeline
from PIL import Image
import numpy as np
from io import BytesIO
import io
import base64


# Initialize segmentation pipeline once at import time (downloads the
# model on first use).
segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
13
def encode_image_to_base64(image):
    """Return *image* encoded as a base64 string of its PNG representation."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    raw_bytes = buffered.getvalue()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
17
 
18
def segment_clothing(img, clothes=("Hat", "Upper-clothes", "Skirt", "Pants",
                                   "Dress", "Scarf")):
    """Segment the requested clothing items out of *img*.

    Parameters
    ----------
    img : PIL.Image.Image
        Input photo. ``convert`` returns a new image, so the caller's
        object is not mutated.
    clothes : sequence of str
        Segment labels to keep. Default is a tuple instead of the
        original mutable-list default (only membership is tested, so this
        is behaviorally compatible).

    Returns
    -------
    list[tuple[str, str]]
        ``(label, base64_png)`` pairs, one per matching segment.
    """
    # Module-level segmentation pipeline; each segment dict carries a
    # string 'label' and a PIL 'mask'.
    segments = segmenter(img)
    img = img.convert("RGBA")

    result_images = []
    for s in segments:
        if s['label'] not in clothes:
            continue
        # BUG FIX: the committed version used `current_mask` and
        # `resized_img` without ever defining them (their defining lines
        # were deleted in this diff). Rebuild them here. The pipeline's
        # mask appears to match the input image's size, so the mask is
        # applied to a copy of the RGBA image without resizing —
        # TODO(review): confirm mask/image sizes always agree.
        current_mask = np.array(s['mask'])
        resized_img = img.copy()
        final_mask = Image.fromarray(current_mask)
        # Use the segment mask as the alpha channel: garment opaque,
        # everything else transparent.
        resized_img.putalpha(final_mask)

        # Convert the final image to base64
        imageBase64 = encode_image_to_base64(resized_img)
        result_images.append((s['label'], imageBase64))

    return result_images
40
+
41
 
42
+