user-agent commited on
Commit
d2a1709
·
verified ·
1 Parent(s): 1e65fde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -6,22 +6,23 @@ from PIL import Image
6
  from torchvision import transforms
7
  from transformers import AutoModelForImageSegmentation
8
  from typing import Union, List
9
- from loadimg import load_img # Your helper to load from URL or file
10
 
11
  torch.set_float32_matmul_precision("high")
12
 
13
- # Load BiRefNet model
14
- birefnet = AutoModelForImageSegmentation.from_pretrained(
15
- "ZhengPeng7/BiRefNet", trust_remote_code=True
 
16
  )
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- birefnet.to(device)
19
 
20
- # Image transformation
21
  transform_image = transforms.Compose([
22
  transforms.Resize((1024, 1024)),
23
  transforms.ToTensor(),
24
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
25
  ])
26
 
27
  @spaces.GPU
@@ -30,22 +31,33 @@ def process(image: Image.Image) -> Image.Image:
30
  input_tensor = transform_image(image).unsqueeze(0).to(device)
31
 
32
  with torch.no_grad():
33
- preds = birefnet(input_tensor)[-1].sigmoid().cpu()
 
 
 
 
 
 
 
 
34
 
35
- pred = preds[0].squeeze()
36
- mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
37
- binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
 
 
38
 
 
39
  white_bg = Image.new("RGB", image_size, (255, 255, 255))
40
  result = Image.composite(image, white_bg, binary_mask)
41
  return result
42
 
 
43
  @spaces.GPU
44
  def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
45
  results = []
46
 
47
  try:
48
- # Single image upload
49
  if image is not None:
50
  image = image.convert("RGB")
51
  processed = process(image)
@@ -53,7 +65,6 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
53
  processed.save(filename)
54
  return filename
55
 
56
- # Single image from URL
57
  if image_url:
58
  im = load_img(image_url, output_type="pil").convert("RGB")
59
  processed = process(im)
@@ -61,7 +72,6 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
61
  processed.save(filename)
62
  return filename
63
 
64
- # Batch of URLs
65
  if batch_urls:
66
  urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
67
  for url in urls:
@@ -80,7 +90,6 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
80
 
81
  return None
82
 
83
- # Interface
84
  demo = gr.Interface(
85
  fn=handler,
86
  inputs=[
@@ -89,9 +98,9 @@ demo = gr.Interface(
89
  gr.Textbox(label="Comma-separated Image URLs (Batch)"),
90
  ],
91
  outputs=gr.File(label="Output File(s)", file_count="multiple"),
92
- title="Background Remover (White Fill)",
93
  description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
94
  )
95
 
96
  if __name__ == "__main__":
97
- demo.launch(show_error=True, mcp_server=True)
 
6
  from torchvision import transforms
7
  from transformers import AutoModelForImageSegmentation
8
  from typing import Union, List
9
+ from loadimg import load_img
10
 
11
  torch.set_float32_matmul_precision("high")
12
 
13
+ # Load RMBG v1.4 model
14
+ model = AutoModelForImageSegmentation.from_pretrained(
15
+ "briaai/RMBG-1.4",
16
+ trust_remote_code=True
17
  )
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model.to(device)
20
 
21
+ # RMBG v1.4 uses different preprocessing
22
  transform_image = transforms.Compose([
23
  transforms.Resize((1024, 1024)),
24
  transforms.ToTensor(),
25
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
26
  ])
27
 
28
  @spaces.GPU
 
31
  input_tensor = transform_image(image).unsqueeze(0).to(device)
32
 
33
  with torch.no_grad():
34
+ # RMBG v1.4 returns the mask directly
35
+ preds = model(input_tensor)
36
+ # Get the mask - RMBG returns different structure than BiRefNet
37
+ if isinstance(preds, list):
38
+ pred = preds[-1]
39
+ else:
40
+ pred = preds
41
+
42
+ pred = pred.sigmoid().cpu()
43
 
44
+ mask = pred[0].squeeze()
45
+ mask_pil = transforms.ToPILImage()(mask).resize(image_size).convert("L")
46
+
47
+ # Create binary mask
48
+ binary_mask = mask_pil.point(lambda p: 255 if p > 127 else 0)
49
 
50
+ # Apply mask with white background
51
  white_bg = Image.new("RGB", image_size, (255, 255, 255))
52
  result = Image.composite(image, white_bg, binary_mask)
53
  return result
54
 
55
+ # Rest of your code remains the same...
56
  @spaces.GPU
57
  def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
58
  results = []
59
 
60
  try:
 
61
  if image is not None:
62
  image = image.convert("RGB")
63
  processed = process(image)
 
65
  processed.save(filename)
66
  return filename
67
 
 
68
  if image_url:
69
  im = load_img(image_url, output_type="pil").convert("RGB")
70
  processed = process(im)
 
72
  processed.save(filename)
73
  return filename
74
 
 
75
  if batch_urls:
76
  urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
77
  for url in urls:
 
90
 
91
  return None
92
 
 
93
  demo = gr.Interface(
94
  fn=handler,
95
  inputs=[
 
98
  gr.Textbox(label="Comma-separated Image URLs (Batch)"),
99
  ],
100
  outputs=gr.File(label="Output File(s)", file_count="multiple"),
101
+ title="Background Remover (RMBG v1.4)",
102
  description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
103
  )
104
 
105
  if __name__ == "__main__":
106
+ demo.launch(show_error=True, mcp_server=True)