File size: 5,177 Bytes
c9473c9 b7a75e4 de63122 623e1bf de63122 4d7e87d b7a75e4 de63122 b7a75e4 4d7e87d de63122 4d7e87d 623e1bf de63122 623e1bf de63122 623e1bf 4d7e87d c9473c9 f397a20 c9473c9 b7a75e4 de63122 4d7e87d de63122 4d7e87d b7a75e4 f397a20 c9473c9 f397a20 c9473c9 b7a75e4 4d7e87d de63122 c333b0b 4d7e87d de63122 4d7e87d de63122 4d7e87d de63122 4d7e87d de63122 4d7e87d de63122 4d7e87d c333b0b f397a20 36a76ae c333b0b de63122 36a76ae 4d7e87d 36a76ae de63122 36a76ae de63122 4d7e87d 36a76ae 4d7e87d 36a76ae 1018e38 36a76ae de63122 36a76ae 4d7e87d 36a76ae 4d7e87d 36a76ae 623e1bf c9473c9 e041428 c9473c9 de63122 623e1bf b7a75e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import time
import gc
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
from transformers import AutoConfig, AutoModelForImageSegmentation
# 1) Wrap config loading in a helper that monkey-patches a dummy get_text_config().
def load_model(model_id="zhengpeng7/BiRefNet_lite"):
    """Load BiRefNet, move it to the best available device, and set eval mode.

    A dummy ``get_text_config()`` is attached to the config before the model
    is built so that ``tie_weights()`` can read ``tie_word_embeddings``
    without failing.

    Note: ``trust_remote_code=True`` executes Python code from the model
    repository; only point ``model_id`` at repositories you trust.

    Args:
        model_id: Hugging Face repo id (or local path) passed to both
            ``AutoConfig.from_pretrained`` and
            ``AutoModelForImageSegmentation.from_pretrained``.

    Returns:
        ``(model, device)`` where ``device`` is the string ``'cuda'`` when CUDA
        is available, else ``'cpu'``.
    """
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    config.is_encoder_decoder = False

    class _DummyTextConfig:
        # Minimal stand-in: only the attribute tie_weights() needs to read.
        tie_word_embeddings = False

    def dummy_text_config(*args, **kwargs):
        # Accept any call signature: depending on the transformers version,
        # get_text_config() may be called with decoder= and/or encoder=.
        return _DummyTextConfig()

    # Set as an instance attribute, so it is called without `self`.
    config.get_text_config = dummy_text_config

    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        config=config,
        trust_remote_code=True,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    return model, device
# 2) Load the model once at import time; the handlers below read these globals.
birefnet, device = load_model()

# 3) Preprocessing: resize to a fixed 1024x1024 input, convert to a tensor,
# then normalize each channel with the usual ImageNet mean/std values.
image_size = (1024, 1024)
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def run_inference(images, model, device):
    """Run the segmentation model on a batch of images and return RGBA cut-outs.

    Each image is preprocessed with ``transform_image``. The whole batch goes
    through ``model`` in one forward pass, and the sigmoid of the last output
    element is used as a soft foreground mask. Each mask is resized back to
    its image's original size and installed as that image's alpha channel.

    Args:
        images: List of PIL images.
        model: Segmentation model; ``model(batch)[-1]`` is expected to hold
            the mask logits (assumed shape ``(B, 1, H, W)`` -- TODO confirm).
        device: Device the model lives on (``'cuda'`` or ``'cpu'``).

    Returns:
        List of RGBA images, one per input, in input order. Empty input
        yields an empty list.

    Raises:
        torch.OutOfMemoryError: if the batch does not fit on the device.
    """
    if not images:
        # torch.stack() rejects an empty list.
        return []

    input_tensor = output = None
    try:
        input_tensor = torch.stack([transform_image(img) for img in images]).to(device)
        with torch.no_grad():
            output = model(input_tensor)
            # Per the upstream BiRefNet example, the last element of the
            # output holds the final mask. Adapt (e.g. output.logits) for
            # models with a different output structure.
            preds = output[-1].sigmoid().cpu()
    finally:
        # Drop every device-side reference (including the raw model output)
        # before releasing cached memory. This runs on success and on OOM;
        # an OOM then propagates to the caller.
        input_tensor = output = None
        gc.collect()
        torch.cuda.empty_cache()

    to_pil = transforms.ToPILImage()
    results = []
    for img, pred in zip(images, preds):
        mask = to_pil(pred.squeeze()).resize(img.size)
        # putalpha keeps the original colours. Pasting through the mask onto
        # a transparent canvas would also scale the RGB bands by the mask,
        # which darkens soft edges.
        result = img.convert("RGBA")
        result.putalpha(mask)
        results.append(result)
    return results
def binary_search_max(images):
    """Find the largest leading batch of ``images`` that runs without OOM.

    Binary-searches the batch size over ``[1, len(images)]``. Each attempt
    reloads the global model and then runs ``run_inference`` on
    ``images[:mid]``.

    Args:
        images: List of PIL images, in upload order.

    Returns:
        ``(best, best_count)``: the results of the largest batch that
        succeeded and its size, or ``(None, 0)`` if no batch fit (or
        ``images`` is empty).

    Side effects:
        Rebinds the module-level ``birefnet`` and ``device`` globals.
    """
    global birefnet, device

    if not images:
        return None, 0

    low, high = 1, len(images)
    best, best_count = None, 0
    while low <= high:
        mid = (low + high) // 2
        # Re-load the model to avoid leftover memory fragmentation. Drop the
        # old copy *first*; rebinding the global only after load_model()
        # returns would keep two copies resident while the new one loads.
        birefnet = None
        gc.collect()
        torch.cuda.empty_cache()
        try:
            birefnet, device = load_model()
            res = run_inference(images[:mid], birefnet, device)
        except torch.OutOfMemoryError:
            # Keep the handler trivial. The cleanup at the top of the next
            # iteration runs after the exception (and its traceback) is gone.
            high = mid - 1
        else:
            best, best_count = res, mid
            low = mid + 1

    # A reload can itself run out of memory and leave no model behind.
    # Restore one so later requests keep working; this raises if even a
    # lone model copy no longer fits.
    if birefnet is None:
        birefnet, device = load_model()
    return best, best_count
def extract_objects(filepaths):
    """Gradio handler: cut out the foreground of every uploaded image.

    It first tries to process all images in one batch. If that runs out of
    memory, ``binary_search_max`` finds the largest leading batch that fits,
    and only that batch's results are returned.

    Args:
        filepaths: Paths of the uploaded files from ``gr.Files``
            (presumably ``None`` or empty when nothing was uploaded).

    Returns:
        ``(results, summary)``: a list of RGBA cut-outs for the gallery and
        a human-readable timing summary.
    """
    if not filepaths:
        return [], "No images uploaded."

    images = []
    for path in filepaths:
        # Use a context manager so the file handle is closed deterministically.
        with Image.open(path) as im:
            images.append(im.convert("RGB"))

    start_time = time.perf_counter()

    # First attempt: all images at once.
    try:
        results = run_inference(images, birefnet, device)
    except torch.OutOfMemoryError:
        # Only record the failure here. The fallback must run *after* this
        # clause: while the exception is being handled, its traceback keeps
        # the failed forward pass's frames (and their tensors) alive.
        initial_attempt_time = time.perf_counter() - start_time
    else:
        total_time = time.perf_counter() - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary

    # Fallback: search for the largest batch that fits.
    best, best_count = binary_search_max(images)
    total_time = time.perf_counter() - start_time

    if best is None:
        # Not even 1 image can be processed
        summary = (
            f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
            f"Could not process even a single image.\n"
            f"Total time including fallback attempts: {total_time:.2f}s."
        )
        return [], summary

    summary = (
        f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
        f"Found that {best_count} images can be processed without OOM.\n"
        f"Total time including fallback attempts: {total_time:.2f}s.\n"
        f"Next time, try using up to {best_count} images."
    )
    return best, summary
# Gradio UI: a multi-file upload goes in; a gallery of cut-outs plus a timing
# summary comes out.
iface = gr.Interface(
    fn=extract_objects,
    inputs=gr.Files(
        label="Upload Multiple Images",
        type="filepath",
        file_count="multiple",
    ),
    outputs=[
        gr.Gallery(label="Processed Images"),
        gr.Textbox(label="Timing Info"),
    ],
    title="BiRefNet Bulk Background Removal (with fallback)",
    description="Upload multiple images. If OOM occurs, we fallback to smaller batches.",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    iface.launch()
|