Jannat24 committed on
Commit
f233ca1
·
verified ·
1 Parent(s): 49e868c

Update modules/segmentface.py

Browse files
Files changed (1) hide show
  1. modules/segmentface.py +35 -38
modules/segmentface.py CHANGED
@@ -1,75 +1,72 @@
1
  import cv2
2
  import mediapipe as mp
3
  import numpy as np
4
- from rembg import remove
5
- from PIL import Image
 
6
 
class FaceSegmenter:
    """Cut the first detected face out of an image file using MediaPipe.

    Face detection supplies a bounding box, selfie segmentation supplies a
    foreground mask, and the result keeps only the foreground pixels that
    fall inside the box. When no face is detected, ``rembg`` is used to
    strip the background from the whole image instead.
    """

    def __init__(self, threshold=0.5):
        """Create the MediaPipe models.

        Args:
            threshold: Cut-off applied to the selfie-segmentation mask;
                pixels whose mask value is greater than this are kept.
        """
        self.threshold = threshold
        # Per the MediaPipe docs, model_selection=1 picks the full-range face
        # detector (0 is the short-range model for close-up faces).
        self.face_detection = mp.solutions.face_detection.FaceDetection(
            model_selection=1,
            min_detection_confidence=0.5,
        )
        # Per the MediaPipe docs, model_selection=1 picks the landscape
        # segmentation model (0 is the general model).
        self.selfie_segmentation = mp.solutions.selfie_segmentation.SelfieSegmentation(
            model_selection=1
        )

    def segment_face(self, image_path):
        """Return the image at ``image_path`` with everything but the face blacked out.

        Args:
            image_path: Path of the image file to read.

        Returns:
            A numpy image array. When a face is found it has the channel
            order produced by ``cv2.imread`` with non-face pixels set to zero.
            When no face is found it is the ``rembg`` output with any alpha
            channel dropped.

        Raises:
            ValueError: If the image cannot be loaded or segmentation
                produces no mask.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError("Image not found or unable to load.")

        # MediaPipe expects RGB input, OpenCV loads BGR.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        face_results = self.face_detection.process(rgb_image)
        if not face_results.detections:
            return self._remove_background(image_path)

        # Only the first detection is used.
        detection = face_results.detections[0]
        bboxC = detection.location_data.relative_bounding_box
        h, w, _ = image.shape
        x, y = int(bboxC.xmin * w), int(bboxC.ymin * h)
        width, height = int(bboxC.width * w), int(bboxC.height * h)

        # The relative box may start outside the frame (negative origin).
        # Negative slice bounds would be read as "from the end" by NumPy and
        # select the wrong region, so clamp them to zero. Bounds beyond the
        # image size are already clipped by NumPy slicing.
        x_start, x_end = max(x, 0), max(x + width, 0)
        y_start, y_end = max(y, 0), max(y + height, 0)

        segmentation_results = self.selfie_segmentation.process(rgb_image)
        if segmentation_results.segmentation_mask is None:
            raise ValueError("Segmentation failed.")

        mask = (segmentation_results.segmentation_mask > self.threshold).astype(np.uint8)

        # Keep the foreground mask only inside the face bounding box.
        face_mask = np.zeros_like(mask)
        face_mask[y_start:y_end, x_start:x_end] = mask[y_start:y_end, x_start:x_end]

        segmented_face = cv2.bitwise_and(image, image, mask=face_mask)

        return segmented_face

    def _remove_background(self, image_path):
        """Fallback for images without a detected face: strip the background with rembg."""
        # Imported locally because the module's import block does not provide
        # ``io``; without it this fallback raised NameError.
        import io

        with open(image_path, "rb") as input_file:
            input_image = input_file.read()
        output_image = remove(input_image)
        output_image = np.array(Image.open(io.BytesIO(output_image)))
        # NOTE(review): this path returns RGB channel order while the face
        # path returns the BGR order of cv2.imread - confirm what callers
        # (including save/show below, which use OpenCV) expect.
        if output_image.ndim == 3 and output_image.shape[2] == 4:
            output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2RGB)
        return output_image

    def save_segmented_face(self, image_path, output_path):
        """Segment the image at ``image_path`` and write the result to ``output_path``."""
        segmented_face = self.segment_face(image_path)
        cv2.imwrite(output_path, segmented_face)

    def show_segmented_face(self, image_path):
        """Segment the image at ``image_path`` and display it until a key is pressed."""
        segmented_face = self.segment_face(image_path)
        cv2.imshow("Segmented Face", segmented_face)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
 
1
  import cv2
2
  import mediapipe as mp
3
  import numpy as np
4
+ from rembg import remove
5
+ from PIL import Image
6
+ import io
7
 
class FaceSegmenter:
    """Cut the first detected face out of an image using MediaPipe.

    Face detection supplies a bounding box, selfie segmentation supplies a
    foreground mask, and the result keeps only the foreground pixels that
    fall inside the box. When no face is detected, ``rembg`` is used to
    strip the background from the whole image instead.
    """

    def __init__(self, threshold=0.5):
        """Create the MediaPipe models.

        Args:
            threshold: Cut-off applied to the selfie-segmentation mask;
                pixels whose mask value is greater than this are kept.
        """
        self.threshold = threshold
        # Per the MediaPipe docs, model_selection=1 picks the full-range face
        # detector (0 is the short-range model for close-up faces).
        self.face_detection = mp.solutions.face_detection.FaceDetection(
            model_selection=1,
            min_detection_confidence=0.5,
        )
        # Per the MediaPipe docs, model_selection=1 picks the landscape
        # segmentation model (0 is the general model).
        self.selfie_segmentation = mp.solutions.selfie_segmentation.SelfieSegmentation(
            model_selection=1
        )

    def segment_face(self, image_input):
        """Return ``image_input`` with everything but the face blacked out.

        Args:
            image_input: Either a file path string, or a numpy array that is
                treated as a BGR image (it is copied, not modified).

        Returns:
            A numpy image array. When a face is found it is in BGR order with
            non-face pixels set to zero. When no face is found it is the
            ``rembg`` output with any alpha channel dropped.

        Raises:
            ValueError: If the input type is unsupported, the file cannot be
                loaded, or segmentation produces no mask.
        """
        if isinstance(image_input, str):
            image = cv2.imread(image_input)
            if image is None:
                raise ValueError("Image not found or unable to load.")
        elif isinstance(image_input, np.ndarray):
            image = image_input.copy()
        else:
            raise ValueError("Input must be file path string or numpy array")

        # MediaPipe expects RGB input, OpenCV images are BGR.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        face_results = self.face_detection.process(rgb_image)
        if not face_results.detections:
            return self._remove_background(rgb_image)

        # Only the first detection is used.
        detection = face_results.detections[0]
        bboxC = detection.location_data.relative_bounding_box
        h, w, _ = image.shape
        x, y = int(bboxC.xmin * w), int(bboxC.ymin * h)
        width, height = int(bboxC.width * w), int(bboxC.height * h)

        # The relative box may start outside the frame (negative origin).
        # Negative slice bounds would be read as "from the end" by NumPy and
        # select the wrong region, so clamp them to zero. Bounds beyond the
        # image size are already clipped by NumPy slicing.
        x_start, x_end = max(x, 0), max(x + width, 0)
        y_start, y_end = max(y, 0), max(y + height, 0)

        segmentation_results = self.selfie_segmentation.process(rgb_image)
        if segmentation_results.segmentation_mask is None:
            # Fail with a clear message instead of a TypeError on the
            # comparison below.
            raise ValueError("Segmentation failed.")

        mask = (segmentation_results.segmentation_mask > self.threshold).astype(np.uint8)

        # Keep the foreground mask only inside the face bounding box.
        face_mask = np.zeros_like(mask)
        face_mask[y_start:y_end, x_start:x_end] = mask[y_start:y_end, x_start:x_end]

        segmented_face = cv2.bitwise_and(image, image, mask=face_mask)

        return segmented_face

    def _remove_background(self, rgb_image):
        """Fallback for images without a detected face: strip the background with rembg."""
        output_image = np.array(remove(Image.fromarray(rgb_image)))
        # NOTE(review): this path returns RGB channel order while the face
        # path returns BGR - confirm what callers (including save/show below,
        # which use OpenCV) expect.
        if output_image.ndim == 3 and output_image.shape[2] == 4:
            output_image = cv2.cvtColor(output_image, cv2.COLOR_RGBA2RGB)
        return output_image

    def save_segmented_face(self, image_input, output_path):
        """Segment ``image_input`` (path or array) and write the result to ``output_path``."""
        segmented_face = self.segment_face(image_input)
        cv2.imwrite(output_path, segmented_face)

    def show_segmented_face(self, image_input):
        """Segment ``image_input`` (path or array) and display it until a key is pressed."""
        segmented_face = self.segment_face(image_input)
        cv2.imshow("Segmented Face", segmented_face)
        cv2.waitKey(0)
        cv2.destroyAllWindows()