import torch
from transformers import AutoModelForImageSegmentation
from PIL import Image
from torchvision import transforms
import gradio as gr

# Load the BiRefNet_lite segmentation model from the Hugging Face Hub.
# trust_remote_code=True lets transformers execute model code shipped in the
# Hub repository; only enable it for repositories you trust.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    'zhengpeng7/BiRefNet_lite', trust_remote_code=True
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
birefnet.to(device)
birefnet.eval()  # switch layers such as dropout / batch norm to eval behavior

# Preprocessing applied before the model: resize to a fixed 1024x1024 input,
# convert to a float tensor, then normalize each of the three channels with
# the given mean/std (the standard ImageNet statistics).
image_size = (1024, 1024)
transform_image = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def extract_object(image):
    """Remove the background of *image* using the BiRefNet model.

    Args:
        image: A PIL image of any mode, or None when nothing was provided.

    Returns:
        An RGBA copy of *image* at its original size whose alpha channel is
        the predicted foreground mask, or None if *image* is None.
    """
    if image is None:
        # e.g. the Gradio form was submitted without an uploaded image.
        return None

    # Normalize above uses exactly three channels, so RGBA, grayscale,
    # palette, ... inputs must be converted to RGB before the transform.
    rgb_image = image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # [-1] takes the last of the model's outputs (presumably the
        # final-stage prediction); sigmoid maps the logits to [0, 1].
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # Drop the batch dim and any singleton channel dim -> 2-D mask
    # (assumes the output is (1, 1, H, W) — TODO confirm for other models).
    pred = preds[0].squeeze()

    # 2-D float tensor in [0, 1] -> single-channel PIL image, resized back to
    # the original resolution so it can be used as the alpha channel.
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)

    image_with_alpha = image.convert("RGBA")
    image_with_alpha.putalpha(mask)
    return image_with_alpha


iface = gr.Interface(
    fn=extract_object,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Image"),
    title="BiRefNet Background Removal",
    description="Upload an image and get the foreground object extracted.",
)

if __name__ == "__main__":
    iface.launch()