Saad0KH commited on
Commit
ff7c6c7
Β·
verified Β·
1 Parent(s): ce48cb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -3
app.py CHANGED
@@ -11,6 +11,11 @@ from SegCloth import segment_clothing
11
  from transparent_background import Remover
12
  import threading
13
  import logging
 
 
 
 
 
14
 
15
  app = Flask(__name__)
16
 
@@ -35,6 +40,22 @@ def load_model():
35
  detector = model
36
  logging.info("Model loaded successfully.")
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def save_image(img):
39
  unique_name = str(uuid.uuid4()) + ".png"
40
  img.save(unique_name)
@@ -51,7 +72,24 @@ def encode_image_to_base64(image):
51
  buffered = BytesIO()
52
  image.save(buffered, format="PNG") # Use PNG for compatibility with RGBA
53
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
54
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @spaces.GPU
56
  def remove_background(image):
57
  remover = Remover()
@@ -73,7 +111,7 @@ def detect_and_segment_persons(image, clothes):
73
 
74
  bboxes, kpss = detector.detect(img)
75
  if bboxes.shape[0] == 0:
76
- return [save_image(remove_background(image))]
77
 
78
  height, width, _ = img.shape
79
  bboxes = np.round(bboxes[:, :4]).astype(int)
@@ -89,7 +127,7 @@ def detect_and_segment_persons(image, clothes):
89
  person_img = img[y1:y2, x1:x2]
90
  pil_img = Image.fromarray(person_img[:, :, ::-1])
91
 
92
- img_rm_background = remove_background(pil_img)
93
  segmented_result = segment_clothing(img_rm_background, clothes)
94
  image_paths = [save_image(img) for img in segmented_result]
95
  print(image_paths)
 
11
  from transparent_background import Remover
12
  import threading
13
  import logging
14
+ import uuid
15
+ from transformers import AutoModelForImageSegmentation
16
+ import torch
17
+ from torchvision import transforms
18
+
19
 
20
  app = Flask(__name__)
21
 
 
40
  detector = model
41
  logging.info("Model loaded successfully.")
42
 
43
+ torch.set_float32_matmul_precision(["high", "highest"][0])
44
+
45
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
46
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
47
+ )
48
+ birefnet.to("cuda")
49
+ transform_image = transforms.Compose(
50
+ [
51
+ transforms.Resize((1024, 1024)),
52
+ transforms.ToTensor(),
53
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
54
+ ]
55
+ )
56
+
57
+
58
+
59
  def save_image(img):
60
  unique_name = str(uuid.uuid4()) + ".png"
61
  img.save(unique_name)
 
72
  buffered = BytesIO()
73
  image.save(buffered, format="PNG") # Use PNG for compatibility with RGBA
74
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
75
+
76
+ @spaces.GPU
77
+ def rm_background(image):
78
+ im = load_img(image, output_type="pil")
79
+ im = im.convert("RGB")
80
+ image_size = im.size
81
+ origin = im.copy()
82
+ image = load_img(im)
83
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
84
+ # Prediction
85
+ with torch.no_grad():
86
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
87
+ pred = preds[0].squeeze()
88
+ pred_pil = transforms.ToPILImage()(pred)
89
+ mask = pred_pil.resize(image_size)
90
+ image.putalpha(mask)
91
+ return (image)
92
+
93
  @spaces.GPU
94
  def remove_background(image):
95
  remover = Remover()
 
111
 
112
  bboxes, kpss = detector.detect(img)
113
  if bboxes.shape[0] == 0:
114
+ return [save_image(rm_background(image))]
115
 
116
  height, width, _ = img.shape
117
  bboxes = np.round(bboxes[:, :4]).astype(int)
 
127
  person_img = img[y1:y2, x1:x2]
128
  pil_img = Image.fromarray(person_img[:, :, ::-1])
129
 
130
+ img_rm_background = rm_background(pil_img)
131
  segmented_result = segment_clothing(img_rm_background, clothes)
132
  image_paths = [save_image(img) for img in segmented_result]
133
  print(image_paths)