petergpt commited on
Commit
b7a75e4
Β·
verified Β·
1 Parent(s): 10bba6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -1,13 +1,28 @@
1
  import time
2
  import torch
3
- from transformers import AutoModelForImageSegmentation
 
4
  from PIL import Image
5
  from torchvision import transforms
6
  import gradio as gr
7
- import gc
8
 
9
  def load_model():
10
- model = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
  model.to(device)
13
  model.eval()
@@ -30,13 +45,17 @@ def run_inference(images, model, device):
30
  original_sizes.append(img.size)
31
  inputs.append(transform_image(img))
32
  input_tensor = torch.stack(inputs).to(device)
 
33
  try:
34
  with torch.no_grad():
 
 
35
  preds = model(input_tensor)[-1].sigmoid().cpu()
36
  except torch.OutOfMemoryError:
37
  del input_tensor
38
  torch.cuda.empty_cache()
39
  raise
 
40
  # Post-process
41
  results = []
42
  for i, img in enumerate(images):
@@ -46,6 +65,7 @@ def run_inference(images, model, device):
46
  result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
47
  result.paste(img, mask=mask)
48
  results.append(result)
 
49
  # Cleanup
50
  del input_tensor, preds
51
  gc.collect()
@@ -61,9 +81,8 @@ def binary_search_max(images):
61
  mid = (low + high) // 2
62
  batch = images[:mid]
63
  try:
64
- # Re-load model to avoid leftover memory fragmentation
65
  global birefnet, device
66
- birefnet, device = load_model()
67
  res = run_inference(batch, birefnet, device)
68
  best = res
69
  best_count = mid
@@ -84,7 +103,7 @@ def extract_objects(filepaths):
84
  summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
85
  return results, summary
86
  except torch.OutOfMemoryError:
87
- # OOM occurred, try to find feasible batch size now
88
  oom_time = time.time()
89
  initial_attempt_time = oom_time - start_time
90
 
@@ -114,7 +133,8 @@ iface = gr.Interface(
114
  inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
115
  outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
116
  title="BiRefNet Bulk Background Removal with On-Demand Fallback",
117
- description="Upload as many images as you want. If OOM occurs, a fallback will find the max feasible number. Extra cleanup steps and reinitialization for more consistent results."
118
  )
119
 
120
- iface.launch()
 
 
1
  import time
2
  import torch
3
+ import gc
4
+ from transformers import AutoConfig, AutoModelForImageSegmentation
5
  from PIL import Image
6
  from torchvision import transforms
7
  import gradio as gr
 
8
 
9
  def load_model():
10
+ # Fetch the config first (with trust_remote_code=True)
11
+ config = AutoConfig.from_pretrained("zhengpeng7/BiRefNet_lite", trust_remote_code=True)
12
+
13
+ # Ensure it's not treated as a seq2seq model
14
+ config.is_encoder_decoder = False
15
+
16
+ # Optionally, block calls to get_text_config if needed:
17
+ # config.get_text_config = lambda decoder=True: None
18
+
19
+ # Now load the model with our tweaked config
20
+ model = AutoModelForImageSegmentation.from_pretrained(
21
+ "zhengpeng7/BiRefNet_lite",
22
+ config=config,
23
+ trust_remote_code=True
24
+ )
25
+
26
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
27
  model.to(device)
28
  model.eval()
 
45
  original_sizes.append(img.size)
46
  inputs.append(transform_image(img))
47
  input_tensor = torch.stack(inputs).to(device)
48
+
49
  try:
50
  with torch.no_grad():
51
+ # If the last layer is returned as [-1],
52
+ # adjust accordingly or see how your model outputs are structured
53
  preds = model(input_tensor)[-1].sigmoid().cpu()
54
  except torch.OutOfMemoryError:
55
  del input_tensor
56
  torch.cuda.empty_cache()
57
  raise
58
+
59
  # Post-process
60
  results = []
61
  for i, img in enumerate(images):
 
65
  result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
66
  result.paste(img, mask=mask)
67
  results.append(result)
68
+
69
  # Cleanup
70
  del input_tensor, preds
71
  gc.collect()
 
81
  mid = (low + high) // 2
82
  batch = images[:mid]
83
  try:
 
84
  global birefnet, device
85
+ birefnet, device = load_model() # re-init to reduce memory fragmentation
86
  res = run_inference(batch, birefnet, device)
87
  best = res
88
  best_count = mid
 
103
  summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
104
  return results, summary
105
  except torch.OutOfMemoryError:
106
+ # OOM occurred, try fallback
107
  oom_time = time.time()
108
  initial_attempt_time = oom_time - start_time
109
 
 
133
  inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
134
  outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
135
  title="BiRefNet Bulk Background Removal with On-Demand Fallback",
136
+ description="Upload as many images as you want. If OOM occurs, fallback logic will find the max feasible number."
137
  )
138
 
139
+ if __name__ == "__main__":
140
+ iface.launch()