petergpt committed on
Commit
f397a20
·
verified ·
1 Parent(s): c333b0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -35
app.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
  from torchvision import transforms
6
  import gradio as gr
7
 
8
- # Load model
9
  birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
  birefnet.to(device)
@@ -19,72 +19,85 @@ transform_image = transforms.Compose([
19
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20
  ])
21
 
22
- def process_batch(img_batch):
 
23
  inputs = []
24
  original_sizes = []
25
- for img in img_batch:
26
  original_sizes.append(img.size)
27
  inputs.append(transform_image(img))
28
  input_tensor = torch.stack(inputs).to(device)
29
 
 
30
  try:
31
  with torch.no_grad():
32
  preds = birefnet(input_tensor)[-1].sigmoid().cpu()
33
  except torch.OutOfMemoryError:
 
34
  torch.cuda.empty_cache()
35
  return None
36
 
 
37
  results = []
38
- for i, img in enumerate(img_batch):
39
  pred = preds[i].squeeze()
40
  pred_pil = transforms.ToPILImage()(pred)
41
  mask = pred_pil.resize(original_sizes[i])
42
  result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
43
  result.paste(img, mask=mask)
44
  results.append(result)
45
-
46
  return results
47
 
48
  def extract_objects(filepaths):
49
- # Open all images from the uploaded file paths
50
- images = [Image.open(path).convert("RGB") for path in filepaths]
51
-
52
- # You can define a batch size here (e.g., batch_size = 5)
53
- # This prevents trying to process all images at once if too large
54
- batch_size = 5
55
- batches = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
56
 
57
  total_start = time.time()
58
- all_results = []
59
- batch_times = []
60
- for b_idx, batch in enumerate(batches):
61
- b_start = time.time()
62
- res = process_batch(batch)
63
- if res is None:
64
- # Handle OOM gracefully
65
- all_results.extend([Image.new("RGBA", (100, 100), (255,0,0,255)) for _ in batch])
66
- batch_times.append(f"Batch {b_idx+1}: OOM Error")
 
 
 
 
 
 
 
 
 
 
 
67
  else:
68
- all_results.extend(res)
69
- b_end = time.time()
70
- batch_times.append(f"Batch {b_idx+1}: {(b_end - b_start):.2f} s")
71
- total_end = time.time()
72
 
73
- summary = (
74
- f"Total request time: {total_end - total_start:.2f} s\n"
75
- "Batch times:\n" + "\n".join(batch_times)
76
- )
77
 
78
- return all_results, summary
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  iface = gr.Interface(
81
  fn=extract_objects,
82
  inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
83
  outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
84
- title="BiRefNet Bulk Background Removal with Queue & Batch",
85
- description="Upload multiple images. The request is queued and processed in batches to avoid OOM errors."
86
  )
87
 
88
- # Enable the queue with defined concurrency to prevent multiple large requests at once
89
- # You can adjust concurrency_count and max_size as needed.
90
- iface.queue(concurrency_count=1, max_size=10).launch()
 
5
  from torchvision import transforms
6
  import gradio as gr
7
 
8
+ # Load the model
9
  birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
  birefnet.to(device)
 
19
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20
  ])
21
 
22
+ def try_inference(images):
23
+ # Convert images to tensors
24
  inputs = []
25
  original_sizes = []
26
+ for img in images:
27
  original_sizes.append(img.size)
28
  inputs.append(transform_image(img))
29
  input_tensor = torch.stack(inputs).to(device)
30
 
31
+ # Attempt inference
32
  try:
33
  with torch.no_grad():
34
  preds = birefnet(input_tensor)[-1].sigmoid().cpu()
35
  except torch.OutOfMemoryError:
36
+ # Clear CUDA cache and return None to indicate OOM
37
  torch.cuda.empty_cache()
38
  return None
39
 
40
+ # Post-process
41
  results = []
42
+ for i, img in enumerate(images):
43
  pred = preds[i].squeeze()
44
  pred_pil = transforms.ToPILImage()(pred)
45
  mask = pred_pil.resize(original_sizes[i])
46
  result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
47
  result.paste(img, mask=mask)
48
  results.append(result)
 
49
  return results
50
 
51
  def extract_objects(filepaths):
52
+ # Open all images
53
+ images = [Image.open(p).convert("RGB") for p in filepaths]
 
 
 
 
 
54
 
55
  total_start = time.time()
56
+ # If you have N images, start by trying them all
57
+ low = 1
58
+ high = len(images)
59
+ best = None
60
+ best_batch_size = 0
61
+
62
+ # Binary search to find max batch size that doesn't OOM
63
+ while low <= high:
64
+ mid = (low + high) // 2
65
+ batch_test = images[:mid]
66
+
67
+ start = time.time()
68
+ results = try_inference(batch_test)
69
+ end = time.time()
70
+
71
+ if results is not None:
72
+ # Succeeded with 'mid' batch size
73
+ best = results
74
+ best_batch_size = mid
75
+ low = mid + 1 # try a bigger batch
76
  else:
77
+ # OOM, try smaller batch
78
+ high = mid - 1
 
 
79
 
80
+ total_end = time.time()
 
 
 
81
 
82
+ if best is None:
83
+ # Even a single image caused OOM
84
+ summary = "Could not process even a single image without OOM."
85
+ return [], summary
86
+ else:
87
+ # Process the final chosen batch size fully (we already have results)
88
+ summary = (
89
+ f"Total request time: {total_end - total_start:.2f} s\n"
90
+ f"Successfully processed {best_batch_size} images in a single batch.\n"
91
+ f"Could not handle more than {best_batch_size} images without OOM."
92
+ )
93
+ return best, summary
94
 
95
  iface = gr.Interface(
96
  fn=extract_objects,
97
  inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
98
  outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
99
+ title="BiRefNet Dynamic Batch OOM Test",
100
+ description="Upload images. The system will try to process all at once, and if OOM occurs, it will try smaller batches automatically, quickly finding the largest feasible batch size."
101
  )
102
 
103
+ iface.launch()