holiday_cards / app.py
Amit Gazal
wip
e25fc54
raw
history blame
2.24 kB
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Load the RMBG-2.0 segmentation model. trust_remote_code=True lets
# transformers execute the model code hosted in the hub repository.
model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)

# Allow faster, reduced-precision float32 matmul kernels ("high" rather than
# the strictest "highest" setting).
torch.set_float32_matmul_precision("high")

# Use the GPU when one is present, then switch to evaluation mode for
# inference.
if torch.cuda.is_available():
    model = model.to("cuda")
model.eval()
# Resolution the segmentation model is fed at; the predicted mask is resized
# back to the original image size afterwards.
_MODEL_INPUT_SIZE = (1024, 1024)

# Built once at import time instead of on every request. The mean/std values
# are the standard ImageNet normalization constants.
_transform_image = transforms.Compose([
    transforms.Resize(_MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _predict_mask(image):
    """Run the segmentation model on an RGB PIL image and return its mask.

    The mask is a single-channel 8-bit PIL image resized to ``image.size``;
    higher values mark the region kept as foreground.
    """
    # Send the input to wherever the model's weights live, rather than
    # re-deriving the device independently of how the model was placed.
    device = next(model.parameters()).device
    input_tensor = _transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the last of the model's outputs and squash it into [0, 1] with
        # a sigmoid (presumably the final-stage logits; see the RMBG-2.0
        # model card).
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    return transforms.ToPILImage()(pred).resize(image.size)


def remove_background(input_image, holiday, message):
    """Split an uploaded photo into foreground and background layers.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submitted without uploading anything.
        holiday: Holiday name from the first text box (not used yet).
        message: Optional card message from the second text box (not used yet).

    Returns:
        A 3-tuple of RGBA PIL images:
        - the picture with its background made transparent;
        - the picture with the foreground made transparent;
        - the first image again (placeholder for a future third output).

    Raises:
        gr.Error: if no image was uploaded. Gradio shows this message to the
            user instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    # The Normalize step expects exactly 3 channels. RGBA, grayscale and
    # palette uploads would otherwise crash there.
    image = input_image.convert("RGB")
    mask = _predict_mask(image)

    # Image without background: the mask becomes the alpha channel.
    foreground = image.copy()
    foreground.putalpha(mask)

    # Image with only background: use the inverted mask as alpha.
    background = image.copy()
    background.putalpha(Image.eval(mask, lambda x: 255 - x))

    # TODO: `holiday` and `message` are not used yet, and the third output is
    # a placeholder that duplicates the first.
    return foreground, background, foreground
# Gradio UI: an image plus two text fields go in, three images come out, all
# produced by remove_background.
_card_inputs = [
    gr.Image(type="pil"),
    gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
    gr.Text(
        label="Optional Message",
        placeholder="Enter your holiday message here...",
    ),
]
_card_outputs = [
    gr.Image(type="pil", label="First Output"),
    gr.Image(type="pil", label="Second Output"),
    gr.Image(type="pil", label="Third Output"),
]

demo = gr.Interface(
    fn=remove_background,
    inputs=_card_inputs,
    outputs=_card_outputs,
    title="Holiday Card Generator",
    description="Upload an image to generate a holiday card",
)
demo.launch()