import time
import torch
from transformers import AutoModelForImageSegmentation
from PIL import Image
from torchvision import transforms
import gradio as gr

# Load the model
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
birefnet.to(device)
birefnet.eval()

# Preprocessing
image_size = (1024, 1024)
transform_image = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def run_inference(images):
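    """Segment a batch of PIL images with BiRefNet and return RGBA cutouts at their original sizes."""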
    # Convert all images into a batch tensor
    inputs = []
    original_sizes = []
    for img in images:
        original_sizes.append(img.size)
        inputs.append(transform_image(img))
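    # Stack into a single (N, 3, 1024, 1024) batch and move it to the model's device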
    input_tensor = torch.stack(inputs).to(device)

    # Run inference
    with torch.no_grad():
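        # BiRefNet returns several intermediate maps; the last one is the final prediction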
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()

    # Post-process
    results = []
    for i, img in enumerate(images):
        pred = preds[i].squeeze()
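        # ToPILImage turns the single-channel prediction into a grayscale image usable as a paste mask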
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(original_sizes[i])
        result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
        result.paste(img, mask=mask)
        results.append(result)
    return results

def extract_objects(filepaths):
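    """Remove the background from every uploaded image, falling back to a smaller batch if the GPU runs out of memory."""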
    images = [Image.open(p).convert("RGB") for p in filepaths]
    start_time = time.time()

    # Attempt to process all at once
    try:
        results = run_inference(images)
        end_time = time.time()
        total_time = end_time - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary
    except torch.OutOfMemoryError:
        # Only on failure do we search for a feasible batch size.
        # (torch.OutOfMemoryError is exposed in recent PyTorch releases; older versions raise torch.cuda.OutOfMemoryError.)
        torch.cuda.empty_cache()

        fail_time = time.time()
        initial_attempt_time = fail_time - start_time

        # Binary search to find max feasible batch size
        low, high = 1, len(images)
        best = None       # results from the largest batch that has succeeded so far
        best_count = 0    # number of images in that batch
        
        while low <= high:
            mid = (low + high) // 2
            batch = images[:mid]
            try:
                res = run_inference(batch)
                best = res
                best_count = mid
                low = mid + 1  # try bigger
            except torch.OutOfMemoryError:
                torch.cuda.empty_cache()
                high = mid - 1  # try smaller

        end_time = time.time()
        total_time = end_time - start_time

        if best is None:
            # Not even a single image fits in memory
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                f"Could not process even a single image.\n"
                f"Total time with fallback attempts: {total_time:.2f}s."
            )
            return [], summary
        else:
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s. "
                f"After fallback tests, found that {best_count} images can be processed.\n"
                f"Total time including fallback: {total_time:.2f}s.\n"
                f"Next time, try using up to {best_count} images."
            )
            return best, summary

iface = gr.Interface(
    fn=extract_objects,
    inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
    outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
    description="Upload as many images as you want. If an out-of-memory error occurs, a binary-search fallback finds the largest number of images that can be processed; the extra work is incurred only when the initial attempt fails."
)

iface.launch()