# segment-hair / app.py
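"""Gradio demo that segments hair in an uploaded photo.

The app wraps a MediaPipe ImageSegmenter loaded from the bundled
'emirhan.tflite' model and returns a black-and-white mask of the
detected hair region.
"""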
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.python._framework_bindings import image as image_module
_Image = image_module.Image
from mediapipe.python._framework_bindings import image_frame
_ImageFormat = image_frame.ImageFormat
# Constants for the output mask colors (RGBA)
BG_COLOR = (0, 0, 0, 255)  # black with full opacity
MASK_COLOR = (255, 255, 255, 255)  # white with full opacity
# Create the options that will be used for ImageSegmenter
base_options = python.BaseOptions(model_asset_path='emirhan.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options,
                                       output_category_mask=True)
# Function to segment hair and generate a binary mask
def segment_hair(image):
    # Gradio supplies the image as an RGB numpy array; convert it to RGBA
    rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba_image[:, :, 3] = 0  # Clear the alpha channel

    # Create an MP Image object from the numpy array
    mp_image = _Image(image_format=_ImageFormat.SRGBA, data=rgba_image)
    # Create the image segmenter
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        # Retrieve the masks for the segmented image
        segmentation_result = segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask

        # Generate solid color images for showing the output segmentation mask
        image_data = mp_image.numpy_view()
        fg_image = np.zeros(image_data.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image_data.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR

        # Broadcast the category mask across all four channels and threshold it
        condition = np.stack((category_mask.numpy_view(),) * 4, axis=-1) > 0.2
        output_image = np.where(condition, fg_image, bg_image)

        return cv2.cvtColor(output_image, cv2.COLOR_RGBA2RGB)
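# Example of calling segment_hair() outside Gradio (a sketch; assumes the
# 'example.jpeg' file listed in the Gradio examples below exists locally):
#
#   bgr = cv2.imread("example.jpeg")
#   mask = segment_hair(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
#   cv2.imwrite("hair_mask.png", mask)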
# Gradio interface
iface = gr.Interface(
    fn=segment_hair,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
    title="Hair Segmentation",
    description="Upload an image to segment the hair and generate a mask.",
    examples=["example.jpeg"]
)
if __name__ == "__main__":
    iface.launch()