# Hugging Face Space source file (revision 7cddaa4, file size 3,882 bytes).
# Space status when this copy was captured: runtime error.
import os
import gradio as gr
import cv2
from PIL import Image
import numpy as np
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import spaces # Import ZeroGPU support
# Run on the GPU when CUDA is visible to torch; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Select the "high" float32 matmul precision setting before the model is used.
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation model (its model code is fetched from the
# hub repo, hence trust_remote_code) and place it on the chosen device.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)

# Preprocessing applied to every image before it is fed to the model:
# fixed 1024x1024 resize, tensor conversion, then per-channel normalization
# with three-channel mean/std constants.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# ZeroGPU: the decorator requests a GPU for calls to this function
# (duration=70 is the per-call budget passed to `spaces.GPU`).
@spaces.GPU(duration=70)
def create_mask(image):
    """Segment the subject of *image* with BiRefNet and return it as a mask.

    Args:
        image: PIL image in any mode; it is converted to RGB before
            preprocessing because the normalization step uses three-channel
            mean/std values.

    Returns:
        A PIL image holding the model's sigmoid output, resized back to the
        original ``image.size``.
    """
    image_size = image.size

    # RGBA / grayscale / palette inputs would make the 3-channel Normalize
    # step fail, so force RGB first (no-op for images that already are RGB).
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The last model output is used as the prediction; move it to the CPU
        # so the remaining post-processing does not depend on the GPU.
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    mask_pil = transforms.ToPILImage()(pred)
    # The model ran at 1024x1024; scale the mask back to the input size.
    return mask_pil.resize(image_size)
def apply_filter(image, mask=None, apply_to_subject=True):
    """Blend a half-strength pink tint over the subject or the whole image.

    Each affected pixel becomes the 50/50 average of its RGBA value and the
    tint ``(255, 0, 255, 128)``, truncated to ``uint8``; the alpha channel is
    averaged along with the color channels.

    Args:
        image: PIL image; it is converted to RGBA before blending.
        mask: Optional PIL mask with the same width/height as *image*.
            Pixels whose mask value is greater than 128 count as subject.
        apply_to_subject: When true and *mask* is given, tint only the
            subject pixels; otherwise tint the entire image.

    Returns:
        A new PIL image built from the blended RGBA array.
    """
    image_np = np.array(image.convert("RGBA"))
    # Pink tint in RGBA (alpha = 128).
    pink_color = np.array([255, 0, 255, 128])

    if apply_to_subject and mask is not None:
        # Boolean (H, W) selector of subject pixels. A single vectorized
        # assignment replaces the per-pixel Python loop and produces the
        # same values.
        subject = np.array(mask) > 128
        image_np[subject] = (
            image_np[subject] * 0.5 + pink_color * 0.5
        ).astype(np.uint8)
    else:
        # No mask, or the caller asked for the full image: tint everything.
        image_np = (image_np * 0.5 + pink_color * 0.5).astype(np.uint8)

    return Image.fromarray(image_np)
def process(input_image, subject_choice):
    """Gradio callback: tint the uploaded image according to the user's choice.

    Args:
        input_image: Image array supplied by the Gradio image component, or
            ``None`` when nothing was uploaded.
        subject_choice: Selected radio option; ``"Subject Only"`` triggers
            mask generation, any other value tints the full image.

    Returns:
        The filtered image as returned by ``apply_filter``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error('Please upload an input image')

    original_image = Image.fromarray(input_image)

    # A mask is only computed when the user asked for subject-only tinting.
    wants_subject = subject_choice == "Subject Only"
    mask = create_mask(original_image) if wants_subject else None

    return apply_filter(
        original_image,
        mask,
        wants_subject and mask is not None,
    )
# Gradio interface: inputs (image + target choice + run button) in the left
# column, the filtered result in the right column.
with gr.Blocks() as block:
    with gr.Row():
        gr.Markdown("Apply Pink Filter Effect to Subject or Full Image")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image", height=640)
            subject_choice = gr.Radio(
                choices=["Subject Only", "Full Image"],
                value="Subject Only",
                label="Apply Pink Filter to:",
            )
            run_button = gr.Button("Run")
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Wire the button to the processing function. The listener is registered
    # while the Blocks context is still open.
    run_button.click(
        fn=process,
        inputs=[input_image, subject_choice],
        outputs=output_image,
    )

# Start the app server.
block.launch()