Saad0KH commited on
Commit
95ddb06
Β·
verified Β·
1 Parent(s): 6d8b58d

Update SegCloth.py

Browse files
Files changed (1) hide show
  1. SegCloth.py +15 -7
SegCloth.py CHANGED
@@ -1,14 +1,19 @@
 
1
  from transformers import pipeline
 
2
  from PIL import Image
3
  import numpy as np
4
  from io import BytesIO
5
  import base64
6
- import torch
7
- from torchvision.transforms.functional import to_pil_image
8
 
9
- # Initialize segmentation and super-resolution pipelines
10
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
11
- super_res = pipeline("image-super-resolution", model="CompVis/stable-diffusion-v1-4")
 
 
 
 
 
12
 
13
  def encode_image_to_base64(image):
14
  buffered = BytesIO()
@@ -36,9 +41,12 @@ def segment_and_enhance_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt",
36
  final_mask = Image.fromarray(current_mask)
37
  resized_img.putalpha(final_mask)
38
 
39
- # Enhance image using super-resolution
40
- enhanced_img = super_res(resized_img)
41
-
 
 
 
42
  # Convert the final image to base64
43
  imageBase64 = encode_image_to_base64(enhanced_img)
44
  result_images.append((s['label'], imageBase64))
 
1
+ import torch
2
  from transformers import pipeline
3
+ from diffusers import StableDiffusionPipeline
4
  from PIL import Image
5
  import numpy as np
6
  from io import BytesIO
7
  import base64
 
 
8
 
9
+ # Initialize segmentation pipeline
10
  segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
11
+
12
+ # Initialize Stable Diffusion pipeline
13
+ model_id = "CompVis/stable-diffusion-v1-4"
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
16
+ pipe = pipe.to(device)
17
 
18
  def encode_image_to_base64(image):
19
  buffered = BytesIO()
 
41
  final_mask = Image.fromarray(current_mask)
42
  resized_img.putalpha(final_mask)
43
 
44
+ # Generate prompt for Stable Diffusion
45
+ prompt = f"a clear photo of {s['label']}"
46
+
47
+ # Enhance image using Stable Diffusion
48
+ enhanced_img = pipe(prompt).images[0]
49
+
50
  # Convert the final image to base64
51
  imageBase64 = encode_image_to_base64(enhanced_img)
52
  result_images.append((s['label'], imageBase64))