# bgremoval / app.py
# Last commit 4d7e87d ("Update app.py", verified) by petergpt; 4.17 kB.
import time
import torch
from transformers import AutoModelForImageSegmentation
from PIL import Image
from torchvision import transforms
import gradio as gr
import gc
def load_model():
    """Load BiRefNet-lite, place it on the best available device, and set eval mode.

    Returns:
        (model, device): the loaded model and the device string (``"cuda"`` when
        CUDA is available, otherwise ``"cpu"``) it was moved to.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: trust_remote_code=True executes Python code shipped with the model
    # repository; consider pinning a revision if that matters for this deployment.
    net = AutoModelForImageSegmentation.from_pretrained(
        "zhengpeng7/BiRefNet_lite",
        trust_remote_code=True,
    )
    net.to(target)
    net.eval()  # inference only: disable training-mode behaviour
    return net, target
# Load the model once at import time. These globals are later rebound when the
# OOM fallback reloads the model.
birefnet, device = load_model()

# Preprocessing: resize to exactly 1024x1024 (aspect ratio is not preserved),
# convert to a float tensor, then normalize per channel. The mean/std values
# match the commonly used ImageNet statistics.
image_size = (1024, 1024)
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)
def run_inference(images, model, device):
    """Run the segmentation model on a batch of PIL images and return cut-outs.

    Every image is preprocessed with ``transform_image`` and stacked into a
    single batch, so the whole batch must fit in device memory at once.

    Args:
        images: Sequence of PIL images (``extract_objects`` passes RGB images).
        model: Segmentation model; ``model(batch)[-1]`` is used as the mask logits.
        device: Device to run the batch on (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        A list of RGBA images, one per input and at the input's original size,
        whose alpha channel is the predicted foreground mask. An empty input
        yields an empty list.

    Raises:
        torch.OutOfMemoryError: Re-raised if the forward pass runs out of GPU
            memory, so the caller can retry with a smaller batch.
    """
    if not images:
        # torch.stack([]) would raise; there is nothing to process.
        return []

    original_sizes = [img.size for img in images]
    input_tensor = torch.stack([transform_image(img) for img in images]).to(device)
    try:
        with torch.no_grad():
            # Only the last model output is used, as sigmoid-activated mask logits.
            preds = model(input_tensor)[-1].sigmoid().cpu()
    except torch.OutOfMemoryError:
        # Drop our reference to the batch before propagating. Activations still
        # referenced by the traceback can only be reclaimed after the caller has
        # left its own ``except`` block.
        del input_tensor
        torch.cuda.empty_cache()
        raise

    # Post-process
    to_pil = transforms.ToPILImage()
    results = []
    for img, pred, size in zip(images, preds, original_sizes):
        mask = to_pil(pred.squeeze()).resize(size)
        # Attach the mask as the alpha channel rather than pasting through it.
        # ``paste(img, mask=mask)`` onto a transparent canvas blends all bands,
        # which also scales RGB by the mask and darkens soft edge pixels.
        result = img.convert("RGBA")
        result.putalpha(mask)
        results.append(result)

    # Cleanup
    del input_tensor, preds
    gc.collect()
    torch.cuda.empty_cache()
    return results
def binary_search_max(images):
    """Find the largest leading slice of *images* that runs without CUDA OOM.

    Intended to be called after a full-batch attempt ran out of memory.
    Binary-searches the count ``k`` over ``images[:k]``, assuming that if ``k``
    images fit then any smaller count also fits. Before every probe the global
    model is released and reloaded with ``load_model()``.

    Args:
        images: List of PIL images to process.

    Returns:
        ``(best, best_count)``: the results for the largest successful slice and
        its length, or ``(None, 0)`` if not even one image could be processed.

    Note:
        If ``load_model()`` itself raises, the exception propagates and the
        global ``birefnet`` is left as ``None``.
    """
    global birefnet, device
    low, high = 1, len(images)
    best = None
    best_count = 0
    while low <= high:
        mid = (low + high) // 2

        # Re-load model to avoid leftover memory fragmentation. Drop the global
        # reference to the old model first, so two copies are not resident at the
        # same time. Load outside the ``try`` so that a failure while loading is
        # not mistaken for "this batch is too large".
        birefnet = None
        gc.collect()
        torch.cuda.empty_cache()
        birefnet, device = load_model()

        oom = False
        try:
            res = run_inference(images[:mid], birefnet, device)
        except torch.OutOfMemoryError:
            oom = True

        if oom:
            # Reclaim memory only after leaving the ``except`` block: while it is
            # active, the traceback keeps the failed batch's tensors referenced.
            gc.collect()
            torch.cuda.empty_cache()
            high = mid - 1
        else:
            best, best_count = res, mid
            low = mid + 1
    return best, best_count
def extract_objects(filepaths):
    """Gradio handler: remove the background from every uploaded image.

    First tries all images in one batch. If that raises
    ``torch.OutOfMemoryError``, falls back to ``binary_search_max`` and returns
    only the largest leading subset of images that fits.

    Args:
        filepaths: File paths from the ``gr.Files`` input; ``None`` or empty
            when nothing was uploaded.

    Returns:
        ``(results, summary)``: the list of RGBA cut-outs for the gallery and a
        timing/status message for the textbox.
    """
    if not filepaths:
        # Nothing uploaded: avoid iterating None or stacking an empty batch.
        return [], "No images uploaded."

    images = []
    for p in filepaths:
        # ``with`` closes the file handle; convert() returns an independent copy.
        with Image.open(p) as img:
            images.append(img.convert("RGB"))

    start_time = time.perf_counter()
    # First attempt: all images
    try:
        results = run_inference(images, birefnet, device)
    except torch.OutOfMemoryError:
        # Only record the failure here. Recovery must run after the except block
        # ends: while the exception is being handled, its traceback keeps the
        # failed forward pass's frames (and their GPU tensors) alive.
        oom_time = time.perf_counter()
    else:
        total_time = time.perf_counter() - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary

    # OOM occurred, try to find feasible batch size now
    initial_attempt_time = oom_time - start_time
    gc.collect()
    torch.cuda.empty_cache()
    best, best_count = binary_search_max(images)
    total_time = time.perf_counter() - start_time

    if best is None:
        # Not even 1 image works
        summary = (
            f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
            f"Could not process even a single image.\n"
            f"Total time including fallback attempts: {total_time:.2f}s."
        )
        return [], summary

    summary = (
        f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
        f"Found that {best_count} images can be processed without OOM.\n"
        f"Total time including fallback attempts: {total_time:.2f}s.\n"
        f"Next time, try using up to {best_count} images."
    )
    return best, summary
# Gradio UI: multi-file upload in; a gallery of cut-outs plus a timing summary out.
upload_input = gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple")
result_outputs = [
    gr.Gallery(label="Processed Images"),
    gr.Textbox(label="Timing Info"),
]
iface = gr.Interface(
    fn=extract_objects,
    inputs=upload_input,
    outputs=result_outputs,
    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
    description=(
        "Upload as many images as you want. If OOM occurs, a fallback will find the max feasible number. "
        "Extra cleanup steps and reinitialization for more consistent results."
    ),
)
iface.launch()