Saad0KH committed
Commit 549db45 · verified · 1 Parent(s): ab47e20

Update app.py

Files changed (1)
  app.py +43 -33
app.py CHANGED

@@ -3,17 +3,25 @@ from PIL import Image
  import base64
  from io import BytesIO
  import numpy as np
- import cv2
  import insightface
  import onnxruntime as ort
  import huggingface_hub
  from SegCloth import segment_clothing
  from transparent_background import Remover
+ import threading
+ import logging

  app = Flask(__name__)

- # Load the model
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Load the model lazily
+ model = None
+ detector = None
+
  def load_model():
+     global model, detector
      path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
      options = ort.SessionOptions()
      options.intra_op_num_threads = 8
@@ -22,24 +30,22 @@ def load_model():
          path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
      )
      model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
-     return model
-
- detector = load_model()
- detector.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
+     model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
+     detector = model
+     logging.info("Model loaded successfully.")

- # Function to decode a base64-encoded image into a PIL.Image.Image object
+ # Function to decode a base64 image to PIL.Image.Image
  def decode_image_from_base64(image_data):
      image_data = base64.b64decode(image_data)
-     image = Image.open(BytesIO(image_data)).convert("RGB")  # Convert to RGB to avoid alpha channel issues
+     image = Image.open(BytesIO(image_data)).convert("RGB")
      return image

- # Function to encode a PIL image in base64
+ # Function to encode a PIL image to base64
  def encode_image_to_base64(image):
      buffered = BytesIO()
-     image.save(buffered, format="PNG")
+     image.save(buffered, format="JPEG")  # Use JPEG for potentially better performance
      return base64.b64encode(buffered.getvalue()).decode('utf-8')

- #@spaces.GPU
  def remove_background(image):
      remover = Remover()
      if isinstance(image, Image.Image):
@@ -51,46 +57,41 @@ def remove_background(image):
          raise TypeError("Unsupported image type")
      return output

- # Detect people and segment their clothing
  def detect_and_segment_persons(image, clothes):
      img = np.array(image)
      img = img[:, :, ::-1]  # RGB -> BGR

+     if detector is None:
+         load_model()  # Ensure the model is loaded
+
      bboxes, kpss = detector.detect(img)
-     if bboxes.shape[0] == 0:  # No face detected
+     if bboxes.shape[0] == 0:
          return [encode_image_to_base64(remove_background(image))]

-     height, width, _ = img.shape  # Get image dimensions
-
+     height, width, _ = img.shape
      bboxes = np.round(bboxes[:, :4]).astype(int)
-
-     # Clamp bounding boxes within image boundaries
-     bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)   # x1
-     bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)  # y1
-     bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)   # x2
-     bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)  # y2
+     bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)
+     bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)
+     bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)
+     bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)

      all_segmented_images = []
      for i in range(bboxes.shape[0]):
          bbox = bboxes[i]
          x1, y1, x2, y2 = bbox
          person_img = img[y1:y2, x1:x2]
-
-         # Convert numpy array to PIL Image
-         pil_img = Image.fromarray(person_img[:, :, ::-1])  # BGR -> RGB
-
-         # Segment clothing for the detected person
+         pil_img = Image.fromarray(person_img[:, :, ::-1])
+
          img_rm_background = remove_background(pil_img)
          segmented_result = segment_clothing(img_rm_background, clothes)

-         # Combine the segmented images for all persons
          all_segmented_images.extend(segmented_result)

      return all_segmented_images

  @app.route('/', methods=['GET'])
  def welcome():
-     return "Welcome to Clothing Segmentation API"
+     return "Welcome to Clothing Segmentation API"

  @app.route('/api/detect', methods=['POST'])
  def detect():
@@ -98,15 +99,24 @@ def detect():
          data = request.json
          image_base64 = data['image']
          image = decode_image_from_base64(image_base64)
-
-         # Person detection and segmentation
+
          clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
-         person_images_base64 = detect_and_segment_persons(image, clothes)

-         return jsonify({'images': person_images_base64})
+         # Run the detection and segmentation in a separate thread
+         result = []
+
+         def process_image():
+             nonlocal result
+             result = detect_and_segment_persons(image, clothes)
+
+         thread = threading.Thread(target=process_image)
+         thread.start()
+         thread.join()  # Wait for the thread to finish
+
+         return jsonify({'images': result})
      except Exception as e:
-         print(e)
+         logging.error(f"Error occurred: {e}")
          return jsonify({'error': str(e)}), 500

  if __name__ == "__main__":
-     app.run(debug=True, host="0.0.0.0", port=7860)
+     app.run(debug=True, host="0.0.0.0", port=7860)
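
For reference, a minimal client sketch for the /api/detect endpoint shown in the diff — assuming the server runs locally on port 7860 (as in app.run), and noting that input.jpg and the person_*.jpg output names are illustrative placeholders, not part of the commit:

import base64
import requests  # third-party HTTP client (pip install requests)

# Read a local image and encode it as base64, matching data['image'] in detect()
with open("input.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:7860/api/detect",
    json={"image": image_b64},
    timeout=120,  # detection + segmentation can be slow on CPU
)
resp.raise_for_status()

# The endpoint returns {'images': [...]} with one base64 string per result
for i, img_b64 in enumerate(resp.json()["images"]):
    with open(f"person_{i}.jpg", "wb") as out:
        out.write(base64.b64decode(img_b64))

The .jpg extension is an assumption based on encode_image_to_base64 now writing JPEG; results coming from segment_clothing may use a different format depending on SegCloth.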