File size: 2,077 Bytes
c0827b5
 
 
 
d9dc7e6
 
 
 
c0827b5
d9dc7e6
 
 
c1b3d98
d9dc7e6
6ca5592
d9dc7e6
 
 
 
 
 
c0827b5
 
 
 
 
 
 
 
 
 
d9dc7e6
 
 
c0827b5
6ca5592
 
 
 
c0827b5
d9dc7e6
 
c0827b5
d9dc7e6
c0827b5
 
 
d9dc7e6
c0827b5
d9dc7e6
6ca5592
c0827b5
6ca5592
c0827b5
d9dc7e6
c0827b5
6ca5592
c0827b5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import tempfile
from io import BytesIO

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import requests
from gradio_client import Client, handle_file
from PIL import Image

def get_segmentation_mask(image_url, model_name="1b"):
    """Request a segmentation mask for an image from the ``facebook/sapiens-seg`` Space.

    Args:
        image_url: Local path or URL of the image. It is wrapped with
            ``handle_file`` so the client can upload it to the Space.
        model_name: Model variant forwarded to the Space's ``/process_image``
            endpoint. Defaults to ``"1b"``, the value previously hard-coded.

    Returns:
        The array stored in the file named by the second element of the
        prediction result, loaded with ``np.load``.
        NOTE(review): this assumes that element is the path of a ``.npy``
        per-pixel label mask; confirm against the Space's API page.
    """
    # A new client is created per call, as before (reuse across concurrent
    # requests was not verified safe here).
    client = Client("facebook/sapiens-seg")
    result = client.predict(
        image=handle_file(image_url),
        model_name=model_name,
        api_name="/process_image",
    )
    # The mask is read from result[1]. The previous comment said "Result[2]",
    # which did not match the code. np.load keeps its default
    # allow_pickle=False, so a remote file cannot trigger unpickling.
    return np.load(result[1])

def process_image(image, categories_to_hide):
    """Segment an uploaded image and black out the selected category groups.

    Args:
        image: Value from the ``gr.File`` input. This is either a path
            (``str`` / ``os.PathLike``) or a file-like object whose ``.name``
            is the path of the uploaded file on disk.
        categories_to_hide: Group names chosen in the checkbox group (keys of
            ``grouped_mapping`` below). Unknown names are ignored, and
            ``None`` or an empty list hides nothing.

    Returns:
        A new RGB ``PIL.Image.Image`` in which every pixel whose mask label
        belongs to a selected group is set to black.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Plain paths (including str subclasses) are used directly; other upload
    # objects expose their on-disk path through .name.
    if isinstance(image, (str, os.PathLike)):
        path = os.fspath(image)
    else:
        path = image.name

    # Decode the upload and release the file handle right away.
    with Image.open(path) as src:
        image = src.convert("RGB")

    # Use a private temporary directory per request. A fixed "temp_image.png"
    # in the working directory could be overwritten by a concurrent request
    # before it was uploaded, and it was never cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "temp_image.png")
        image.save(tmp_path)
        mask_data = get_segmentation_mask(tmp_path)

    # Group names -> mask label ids (label meanings per the original author's notes)
    grouped_mapping = {
        "Background": [0],
        "Clothes": [1, 12, 22, 8, 9, 17, 18],  # Includes Shoes, Socks, Slippers
        "Face": [2, 23, 24, 25, 26, 27],  # Face Neck, Lips, Teeth, Tongue
        "Hair": [3],  # Hair
        "Skin (Hands, Feet, Body)": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21]  # Hands, Feet, Arms, Legs, Torso
    }

    # Gather every label id to hide, then black them out in one vectorised
    # pass instead of one full-image comparison per id.
    hidden_ids = [
        idx
        for category in (categories_to_hide or [])
        for idx in grouped_mapping.get(category, [])
    ]

    # np.array copies the pixel data, so the decoded image stays untouched.
    masked_image = np.array(image)
    # Only index when something is selected, matching the old behaviour of
    # never touching the array for an empty selection.
    if hidden_ids:
        masked_image[np.isin(mask_data, hidden_ids)] = [0, 0, 0]

    return Image.fromarray(masked_image)

# Define Gradio Interface
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.File(label="Upload an Image"),
        gr.CheckboxGroup([
            "Background", "Clothes", "Face", "Hair", "Skin (Hands, Feet, Body)"
        ], label="Select Categories to Hide")
    ],
    outputs=gr.Image(label="Masked Image"),
    title="Segmentation Mask Editor",
    description="Upload an image, generate a segmentation mask, and select categories to black out."
)

if __name__ == "__main__":
    demo.launch()