import gradio as gr
from PIL import Image, ImageOps
import matplotlib.pyplot as plt  # NOTE(review): unused in this file; kept to avoid changing imports
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

MODEL_ID = "briaai/RMBG-2.0"
# Resolution the image is resized to before being fed to the model.
IMAGE_SIZE = (1024, 1024)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True executes Python code shipped with the model repository.
model = AutoModelForImageSegmentation.from_pretrained(MODEL_ID, trust_remote_code=True)
torch.set_float32_matmul_precision("high")
model = model.to(DEVICE)
model.eval()

# Built once at import time instead of on every request.
_transform_image = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    # ImageNet mean/std; requires a 3-channel (RGB) input.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _predict_mask(image: Image.Image) -> Image.Image:
    """Run the segmentation model on an RGB *image* and return a single-channel
    foreground mask resized back to the image's own size."""
    input_tensor = _transform_image(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # The last model output is used; sigmoid maps logits to [0, 1].
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    return transforms.ToPILImage()(pred).resize(image.size)


def remove_background(input_image, holiday, message):
    """Split an uploaded image into a foreground cut-out and a background-only image.

    Args:
        input_image: PIL image from the Gradio upload widget, or None when the
            user submits without uploading anything.
        holiday: Holiday name from the UI. Currently unused.
        message: Optional card message from the UI. Currently unused.

    Returns:
        A 3-tuple of RGBA images: (foreground with background transparent,
        background with foreground transparent, foreground again).

    Raises:
        gr.Error: If no image was uploaded (shown to the user by Gradio).
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # Normalize expects exactly 3 channels: RGBA/L/P uploads would otherwise
    # fail, and an existing alpha channel would be silently overwritten below.
    image = input_image.convert("RGB")
    mask = _predict_mask(image)

    foreground = image.copy()
    foreground.putalpha(mask)

    background = image.copy()
    background.putalpha(ImageOps.invert(mask))

    # TODO: `holiday` and `message` are not used yet; the third output repeats
    # the foreground cut-out until holiday-card composition is implemented.
    return foreground, background, foreground


demo = gr.Interface(
    fn=remove_background,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
        gr.Text(label="Optional Message", placeholder="Enter your holiday message here..."),
    ],
    outputs=[
        gr.Image(type="pil", label="First Output"),
        gr.Image(type="pil", label="Second Output"),
        gr.Image(type="pil", label="Third Output"),
    ],
    title="Holiday Card Generator",
    description="Upload an image to generate a holiday card",
)

if __name__ == "__main__":
    demo.launch()