# Hugging Face Spaces page header — Spaces: Sleeping
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Load the RMBG-2.0 segmentation checkpoint once at import time; it is used by
# remove_background() for every request. trust_remote_code lets transformers
# execute the model code published alongside the checkpoint.
model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True,
)

# 'high' permits faster reduced-precision float32 matmuls where the backend
# supports them ('highest' would force full float32 precision).
torch.set_float32_matmul_precision('high')

# Run on the GPU when one is present; otherwise stay on the CPU.
if torch.cuda.is_available():
    model = model.to('cuda')

# Inference only: switch layers such as dropout/batch-norm to eval behavior.
model.eval()
def remove_background(input_image, holiday, message):
    """Split an uploaded photo into foreground and background layers.

    Runs the module-level segmentation ``model`` on ``input_image`` and uses
    the predicted soft mask as an alpha channel.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submitted without uploading anything.
            Any PIL mode is accepted; it is converted to RGB for processing.
        holiday: Holiday name from the first text box. Currently unused.
        message: Optional message from the second text box. Currently unused.

    Returns:
        A 3-tuple of RGBA PIL images, each the same size as ``input_image``:
        the subject with its background made transparent, the background with
        the subject made transparent, and the subject-only image again.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Show a readable message in the UI instead of a torchvision traceback.
        raise gr.Error("Please upload an image.")

    # Normalize() below carries exactly three channel statistics, so RGBA,
    # grayscale or palette uploads must be converted or the transform fails.
    rgb_image = input_image.convert("RGB")

    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform_image(rgb_image).unsqueeze(0)  # add batch dim
    if torch.cuda.is_available():
        # Match the device the model was moved to when it was loaded.
        input_tensor = input_tensor.to('cuda')

    with torch.no_grad():
        # Take the last element of the model output; sigmoid squashes the
        # raw scores into [0, 1].
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # ToPILImage turns the float map into an 8-bit mask; resize it back to the
    # upload's resolution so it can serve as that image's alpha channel.
    mask = transforms.ToPILImage()(pred).resize(rgb_image.size)

    # Foreground: opaque where the model predicts the subject.
    result_image = rgb_image.copy()
    result_image.putalpha(mask)

    # Background: the complementary alpha, opaque where the subject is not.
    only_background_image = rgb_image.copy()
    inverted_mask = Image.eval(mask, lambda x: 255 - x)
    only_background_image.putalpha(inverted_mask)

    # NOTE(review): the third output repeats the foreground image; presumably
    # a placeholder for the finished holiday card — TODO confirm.
    return result_image, only_background_image, result_image
# Gradio UI: a photo plus two free-text fields go in, and the three images
# returned by remove_background() come out.
demo = gr.Interface(
    title="Holiday Card Generator",
    description="Upload an image to generate a holiday card",
    fn=remove_background,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
        gr.Text(
            label="Optional Message",
            placeholder="Enter your holiday message here...",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="First Output"),
        gr.Image(type="pil", label="Second Output"),
        gr.Image(type="pil", label="Third Output"),
    ],
)

# Start the Gradio server.
demo.launch()