petergpt committed
Commit 36a76ae · verified · 1 Parent(s): f397a20

Update app.py

Files changed (1): app.py +54 -48
app.py CHANGED
@@ -19,8 +19,8 @@ transform_image = transforms.Compose([
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 ])

-def try_inference(images):
-    # Convert images to tensors
+def run_inference(images):
+    # Convert all images into a batch tensor
     inputs = []
     original_sizes = []
     for img in images:
@@ -28,14 +28,9 @@ def try_inference(images):
         inputs.append(transform_image(img))
     input_tensor = torch.stack(inputs).to(device)

-    # Attempt inference
-    try:
-        with torch.no_grad():
-            preds = birefnet(input_tensor)[-1].sigmoid().cpu()
-    except torch.OutOfMemoryError:
-        # Clear CUDA cache and return None to indicate OOM
-        torch.cuda.empty_cache()
-        return None
+    # Run inference
+    with torch.no_grad():
+        preds = birefnet(input_tensor)[-1].sigmoid().cpu()

     # Post-process
     results = []
@@ -49,55 +44,66 @@
     return results

 def extract_objects(filepaths):
-    # Open all images
     images = [Image.open(p).convert("RGB") for p in filepaths]
+    start_time = time.time()

-    total_start = time.time()
-    # If you have N images, start by trying them all
-    low = 1
-    high = len(images)
-    best = None
-    best_batch_size = 0
-
-    # Binary search to find max batch size that doesn't OOM
-    while low <= high:
-        mid = (low + high) // 2
-        batch_test = images[:mid]
-
-        start = time.time()
-        results = try_inference(batch_test)
-        end = time.time()
-
-        if results is not None:
-            # Succeeded with 'mid' batch size
-            best = results
-            best_batch_size = mid
-            low = mid + 1  # try a bigger batch
-        else:
-            # OOM, try a smaller batch
-            high = mid - 1
-
-    total_end = time.time()
-
-    if best is None:
-        # Even a single image caused OOM
-        summary = "Could not process even a single image without OOM."
-        return [], summary
-    else:
-        # We already have results for the chosen batch size
-        summary = (
-            f"Total request time: {total_end - total_start:.2f} s\n"
-            f"Successfully processed {best_batch_size} images in a single batch.\n"
-            f"Could not handle more than {best_batch_size} images without OOM."
-        )
-        return best, summary
+    # Attempt to process all images at once
+    try:
+        results = run_inference(images)
+        end_time = time.time()
+        total_time = end_time - start_time
+        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
+        return results, summary
+    except torch.OutOfMemoryError:
+        # Only if the full batch fails do we search for a feasible batch size
+        torch.cuda.empty_cache()
+
+        fail_time = time.time()
+        initial_attempt_time = fail_time - start_time
+
+        # Binary search to find the max feasible batch size
+        low, high = 1, len(images)
+        best = None
+        best_count = 0
+
+        while low <= high:
+            mid = (low + high) // 2
+            batch = images[:mid]
+            try:
+                res = run_inference(batch)
+                best = res
+                best_count = mid
+                low = mid + 1  # try a bigger batch
+            except torch.OutOfMemoryError:
+                torch.cuda.empty_cache()
+                high = mid - 1  # try a smaller batch
+
+        end_time = time.time()
+        total_time = end_time - start_time
+
+        if best is None:
+            # Not even a single image fits in memory
+            summary = (
+                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
+                f"Could not process even a single image.\n"
+                f"Total time with fallback attempts: {total_time:.2f}s."
+            )
+            return [], summary
+        else:
+            summary = (
+                f"Initial attempt OOM after {initial_attempt_time:.2f}s. "
+                f"After fallback tests, found that {best_count} images can be processed.\n"
+                f"Total time including fallback: {total_time:.2f}s.\n"
+                f"Next time, try using up to {best_count} images."
+            )
+            return best, summary

 iface = gr.Interface(
     fn=extract_objects,
     inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
     outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
-    title="BiRefNet Dynamic Batch OOM Test",
-    description="Upload images. The system will try to process all at once, and if OOM occurs, it will try smaller batches automatically, quickly finding the largest feasible batch size."
+    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
+    description="Upload as many images as you want. If OOM occurs, a quick fallback finds the largest feasible batch; the extra overhead is incurred only when needed."
 )

 iface.launch()
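For context, the hunks begin at line 19 of app.py, so the imports and model setup are not part of this diff. Below is a minimal sketch of the preamble the code above assumes; the checkpoint name, loading call, and resize resolution are illustrative assumptions, not taken from the commit.

import time
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Assumed setup: BiRefNet loaded through transformers with remote code,
# as in the reference ZhengPeng7/BiRefNet Spaces; the real app.py may differ.
device = "cuda" if torch.cuda.is_available() else "cpu"
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).to(device)
birefnet.eval()

transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),  # assumed working resolution
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])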
 
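One portability note on the new error handling: torch.OutOfMemoryError exists as a top-level name only in recent PyTorch releases, where it aliases the CUDA allocator's OOM exception. If the Space pins an older torch, catching the namespaced class is the safer choice. A hedged sketch of the same guard (safe_run is a hypothetical helper, not part of the commit):

def safe_run(images):
    # Same pattern as the diff's except clause, but catching
    # torch.cuda.OutOfMemoryError, which predates the top-level alias.
    try:
        return run_inference(images)
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        return None

Either way, calling torch.cuda.empty_cache() after the failure releases cached allocator blocks, so the smaller retry batches in the binary search start with more headroom.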