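# BiRefNet_lite bulk background removal with an on-demand OOM fallback.
# All uploaded images are first processed as a single batch; if CUDA runs out of
# memory, a binary search finds the largest batch size that fits and the summary
# tells the user how many images to upload next time.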
import time
import torch
from transformers import AutoModelForImageSegmentation
from PIL import Image
from torchvision import transforms
import gradio as gr
# Load the model
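# (trust_remote_code=True is needed because the BiRefNet modeling code ships with
# the checkpoint on the Hugging Face Hub rather than inside transformers itself.)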
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
birefnet.to(device)
birefnet.eval()
# Preprocessing
image_size = (1024, 1024)
transform_image = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
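# Each RGB PIL image becomes a [3, 1024, 1024] float tensor normalized with
# ImageNet mean/std, so a batch of N images stacks into [N, 3, 1024, 1024].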
def run_inference(images):
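    """Segment a batch of PIL images and return one RGBA cutout per input.

    The predicted masks are resized back to each image's original size and used
    as the alpha mask when pasting the source image onto a transparent canvas.
    """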
    # Convert all images into a batch tensor
    inputs = []
    original_sizes = []
    for img in images:
        original_sizes.append(img.size)
        inputs.append(transform_image(img))
    input_tensor = torch.stack(inputs).to(device)

    # Run inference
    with torch.no_grad():
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()

    # Post-process
    results = []
    for i, img in enumerate(images):
        pred = preds[i].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(original_sizes[i])
        result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
        result.paste(img, mask=mask)
        results.append(result)
    return results
def extract_objects(filepaths):
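    """Gradio handler: remove the background from every uploaded file.

    The whole upload is tried as one batch first. On a CUDA out-of-memory error,
    a binary search over the batch size finds the largest prefix of the upload
    that fits, and the returned summary reports that limit and the timings.
    """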
    images = [Image.open(p).convert("RGB") for p in filepaths]
    start_time = time.time()

    # Attempt to process all images at once
    try:
        results = run_inference(images)
        end_time = time.time()
        total_time = end_time - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary
    except torch.OutOfMemoryError:
        # Only if the full batch fails do we search for a feasible batch size
        torch.cuda.empty_cache()
        fail_time = time.time()
        initial_attempt_time = fail_time - start_time

        # Binary search to find the max feasible batch size; the full batch
        # already failed, so only strictly smaller sizes need to be tested.
        low, high = 1, len(images) - 1
        best = None
        best_count = 0
        while low <= high:
            mid = (low + high) // 2
            batch = images[:mid]
            try:
                res = run_inference(batch)
                best = res
                best_count = mid
                low = mid + 1  # this size fits; try a bigger one
            except torch.OutOfMemoryError:
                torch.cuda.empty_cache()
                high = mid - 1  # still too big; try a smaller one

        end_time = time.time()
        total_time = end_time - start_time
        if best is None:
            # Not even 1 image works
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                f"Could not process even a single image.\n"
                f"Total time with fallback attempts: {total_time:.2f}s."
            )
            return [], summary
        else:
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s. "
                f"After fallback tests, found that {best_count} images can be processed.\n"
                f"Total time including fallback: {total_time:.2f}s.\n"
                f"Next time, try uploading up to {best_count} images."
            )
            return best, summary
iface = gr.Interface(
    fn=extract_objects,
    inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
    outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
    description="Upload as many images as you want. If an out-of-memory error occurs, a quick fallback finds the maximum feasible number of images; the extra work happens only when it is needed."
)
iface.launch()
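# launch() serves the demo locally (http://127.0.0.1:7860 by default); pass
# share=True if a temporary public link is needed.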