"""Gradio app that removes image backgrounds with the BRIA RMBG-1.4 model."""

import base64
import functools
import io

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import pipeline

# Hugging Face model repository used for segmentation.
MODEL_ID = "briaai/RMBG-1.4"


@functools.lru_cache(maxsize=1)
def _get_segmentor():
    """Build the segmentation pipeline on first use and reuse it afterwards.

    Loading is deferred to the first request and memoized so the model is
    not re-initialized on every button click. A failed load is not cached
    (``lru_cache`` does not store exceptions), so a later call retries.
    """
    # NOTE(review): the RMBG-1.4 model card loads this pipeline with
    # trust_remote_code=True because the repo ships a custom pipeline class.
    # This executes code from the model repository; pin a revision if that
    # is a concern.
    return pipeline(
        "image-segmentation",
        model=MODEL_ID,
        trust_remote_code=True,
        device=-1,  # CPU inference
    )


def remove_background(input_image):
    """Run background removal on ``input_image`` and return the result.

    Args:
        input_image: The image supplied by the Gradio input component
            (configured below with ``type="pil"``), or ``None`` when
            nothing has been uploaded.

    Returns:
        The processed image produced by the segmentation pipeline.

    Raises:
        gr.Error: If no image was provided, or if loading the model or
            processing the image fails.
    """
    if input_image is None:
        # Fail early with a readable message instead of a model stack trace.
        raise gr.Error("Error processing image: no image was provided")

    try:
        result = _get_segmentor()(input_image)

        # NOTE(review): the model card shows the pipeline call returning a
        # PIL image directly; confirm against the deployed model revision.
        if isinstance(result, Image.Image):
            return result

        # Fallback for a mapping-style result, as the original code assumed.
        return result["output_image"]
    except Exception as e:
        raise gr.Error(f"Error processing image: {e}") from e


# Create the Gradio interface.
# The stylesheet is assembled from adjacent literals so the resulting string
# is exactly the original single-line CSS.
css = (
    " .gradio-container {"
    " font-family: 'Segoe UI', sans-serif;"
    " background: linear-gradient(135deg, #1a1a1a 0%, #2d2d2d 100%);"
    " }"
    " .gr-button {"
    " background: linear-gradient(45deg, #FFD700, #FFA500);"
    " border: none;"
    " color: black;"
    " }"
    " .gr-button:hover {"
    " background: linear-gradient(45deg, #FFA500, #FFD700);"
    " transform: translateY(-2px);"
    " box-shadow: 0 4px 12px rgba(255, 215, 0, 0.3);"
    " }"
    " .gr-form {"
    " background: rgba(255, 255, 255, 0.1);"
    " border-radius: 16px;"
    " padding: 20px;"
    " box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);"
    " }"
    " .gr-image {"
    " border-radius: 12px;"
    " border: 2px solid rgba(255, 215, 0, 0.3);"
    " }"
    " "
)

with gr.Blocks(css=css) as demo:
    # Page header.
    gr.HTML(
        """

Background Removal Tool

Powered by RMBG V1.4 model from BRIA AI

"""
    )

    with gr.Row():
        # Left column: upload widget and action buttons.
        with gr.Column():
            # NOTE(review): `tool="upload"` is kept from the original; verify
            # it is accepted by the installed Gradio version.
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                tool="upload",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Remove Background", variant="primary")

        # Right column: processed result.
        with gr.Column():
            output_image = gr.Image(
                label="Result",
                type="pil",
            )
            download_btn = gr.Button("Download Result")

    # Event handlers
    submit_btn.click(
        fn=remove_background,
        inputs=[input_image],
        outputs=[output_image],
    )

    # Reset both image components.
    clear_btn.click(
        lambda: (None, None),
        outputs=[input_image, output_image],
    )

    # Example images
    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
            ["example3.jpg"],
        ],
        inputs=input_image,
        outputs=output_image,
        fn=remove_background,
        cache_examples=True,
    )

    # Download functionality
    # NOTE(review): this handler only passes the current output image back
    # to the same component; it does not itself write or serve a file.
    download_btn.click(
        lambda x: x,
        inputs=[output_image],
        outputs=[output_image],
    )

demo.launch()