Saad0KH commited on
Commit
b33f918
Β·
verified Β·
1 Parent(s): fc3d381

Update SegCloth.py

Browse files
Files changed (1) hide show
  1. SegCloth.py +8 -16
SegCloth.py CHANGED
@@ -1,26 +1,21 @@
1
- import torch
2
  from transformers import pipeline
3
- from diffusers import StableDiffusionPipeline
4
  from PIL import Image
5
  import numpy as np
6
  from io import BytesIO
7
  import base64
 
 
8
 
9
- # Initialize segmentation pipeline
10
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
11
-
12
- # Initialize Stable Diffusion pipeline
13
- model_id = "CompVis/stable-diffusion-v1-4"
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
15
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
16
- pipe = pipe.to(device)
17
 
18
  def encode_image_to_base64(image):
19
  buffered = BytesIO()
20
  image.save(buffered, format="PNG")
21
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
22
 
23
- def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
24
  # Segment image
25
  segments = segmenter(img)
26
 
@@ -41,12 +36,9 @@ def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dr
41
  final_mask = Image.fromarray(current_mask)
42
  resized_img.putalpha(final_mask)
43
 
44
- # Generate prompt for Stable Diffusion
45
- prompt = f"a clear photo of {s['label']}"
46
-
47
- # Enhance image using Stable Diffusion
48
- enhanced_img = pipe(prompt).images[0]
49
-
50
  # Convert the final image to base64
51
  imageBase64 = encode_image_to_base64(enhanced_img)
52
  result_images.append((s['label'], imageBase64))
 
 
1
  from transformers import pipeline
 
2
  from PIL import Image
3
  import numpy as np
4
  from io import BytesIO
5
  import base64
6
+ import torch
7
+ from torchvision.transforms.functional import to_pil_image
8
 
9
+ # Initialize segmentation and super-resolution pipelines
10
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
11
+ super_res = pipeline("image-super-resolution", model="CompVis/stable-diffusion-v-1-4")
 
 
 
 
 
12
 
13
  def encode_image_to_base64(image):
14
  buffered = BytesIO()
15
  image.save(buffered, format="PNG")
16
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
17
 
18
+ def segment_and_enhance_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
19
  # Segment image
20
  segments = segmenter(img)
21
 
 
36
  final_mask = Image.fromarray(current_mask)
37
  resized_img.putalpha(final_mask)
38
 
39
+ # Enhance image using super-resolution
40
+ enhanced_img = super_res(resized_img)
41
+
 
 
 
42
  # Convert the final image to base64
43
  imageBase64 = encode_image_to_base64(enhanced_img)
44
  result_images.append((s['label'], imageBase64))