File size: 3,882 Bytes
7cddaa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import gradio as gr
import cv2
from PIL import Image
import numpy as np
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import spaces  # Import ZeroGPU support

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the BiRefNet segmentation model. "high" allows float32 matmuls to use
# faster reduced-precision internal algorithms where the hardware supports them.
torch.set_float32_matmul_precision("high")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,  # runs model code shipped in the Hub repository
)
birefnet.to(device)

# Preprocessing for BiRefNet: resize to a fixed 1024x1024 input, convert the
# PIL image to a float tensor in [0, 1], then normalise each RGB channel with
# the commonly used ImageNet mean / std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to extract the subject using BiRefNet and create a mask.
# ZeroGPU: a GPU is attached only while this function runs (duration=70 caps
# each call's allocation, in seconds).
@spaces.GPU(duration=70)
def create_mask(image):
    """Segment the salient subject of *image* with BiRefNet.

    Args:
        image: PIL image of any mode. It is converted to RGB first because
            ``transform_image`` normalises exactly three channels; RGBA,
            grayscale or palette images would otherwise fail in ``Normalize``.

    Returns:
        A single-band (mode ``"L"``) PIL image with the same size as *image*.
        Brighter pixels correspond to higher model scores for "subject".
    """
    original_size = image.size  # PIL (width, height), the order resize() expects
    batch = transform_image(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        # Use the last element of the model output; sigmoid squashes the raw
        # scores into [0, 1]. Move to CPU so the result can become a PIL image.
        preds = birefnet(batch)[-1].sigmoid().cpu()

    # NOTE(review): preds is presumably (1, 1, H, W); squeeze() yields a 2-D map.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred)
    # Bring the 1024x1024 prediction back to the caller's image size.
    return mask.resize(original_size)

# Function to apply the pink filter-like color change
def apply_filter(image, mask=None, apply_to_subject=True):
    """Tint an image pink, either everywhere or only inside a mask.

    The R, G and B channels of each affected pixel become a 50/50 blend of the
    original colour and (255, 0, 255). The pixel's own alpha channel is kept,
    so the filter changes colour without making the picture translucent.

    Args:
        image: PIL image to tint; it is converted to RGBA first.
        mask: Optional mask with the same width and height as ``image``
            (e.g. the ``"L"`` image returned by ``create_mask``). A multi-band
            PIL mask is reduced to a single band. Pixels whose mask value is
            greater than 128 are treated as the subject.
        apply_to_subject: When true and ``mask`` is given, tint only the
            subject pixels; otherwise tint the whole image.

    Returns:
        A new RGBA PIL image.
    """
    rgba = np.array(image.convert("RGBA"))
    # Overlay colour for the R, G, B channels only; alpha is left untouched.
    pink_rgb = np.array([255, 0, 255])

    if apply_to_subject and mask is not None:
        if isinstance(mask, Image.Image):
            # Guarantee a 2-D (H, W) array so the threshold yields a pixel mask.
            mask = mask.convert("L")
        subject = np.asarray(mask) > 128  # boolean (H, W) selection of subject pixels
        # Vectorised blend of just the selected pixels' colour channels.
        rgba[subject, :3] = (rgba[subject, :3] * 0.5 + pink_rgb * 0.5).astype(np.uint8)
    else:
        # Apply the pink filter to the whole image if no subject mask is used.
        rgba[..., :3] = (rgba[..., :3] * 0.5 + pink_rgb * 0.5).astype(np.uint8)

    return Image.fromarray(rgba)

# Main processing function for Gradio
def process(input_image, subject_choice):
    """Gradio callback: tint the uploaded picture according to the radio choice.

    Args:
        input_image: Array delivered by the ``gr.Image(type="numpy")`` input,
            or ``None`` when nothing was uploaded.
        subject_choice: Radio value, either ``"Subject Only"`` or ``"Full Image"``.

    Returns:
        The filtered PIL image produced by ``apply_filter``.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if input_image is None:
        raise gr.Error('Please upload an input image')

    pil_image = Image.fromarray(input_image)
    subject_only = subject_choice == "Subject Only"

    # Segmentation is only needed when the tint is restricted to the subject.
    mask = create_mask(pil_image) if subject_only else None

    return apply_filter(pil_image, mask, subject_only and mask is not None)

# Define Gradio Interface
with gr.Blocks() as block:
    with gr.Row():
        gr.Markdown("Apply Pink Filter Effect to Subject or Full Image")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image", height=640)
            subject_choice = gr.Radio(
                choices=["Subject Only", "Full Image"],
                value="Subject Only",
                label="Apply Pink Filter to:",
            )
            run_button = gr.Button("Run")

        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Wire the Run button to the processing function
    run_button.click(
        fn=process,
        inputs=[input_image, subject_choice],
        outputs=output_image,
    )

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    block.launch()