user-agent commited on
Commit
461565d
·
verified ·
1 Parent(s): d2a1709

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -18,7 +18,7 @@ model = AutoModelForImageSegmentation.from_pretrained(
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
20
 
21
- # RMBG v1.4 uses different preprocessing
22
  transform_image = transforms.Compose([
23
  transforms.Resize((1024, 1024)),
24
  transforms.ToTensor(),
@@ -31,20 +31,25 @@ def process(image: Image.Image) -> Image.Image:
31
  input_tensor = transform_image(image).unsqueeze(0).to(device)
32
 
33
  with torch.no_grad():
34
- # RMBG v1.4 returns the mask directly
35
  preds = model(input_tensor)
36
- # Get the mask - RMBG returns different structure than BiRefNet
37
- if isinstance(preds, list):
38
- pred = preds[-1]
 
 
 
39
  else:
40
  pred = preds
41
-
42
- pred = pred.sigmoid().cpu()
 
43
 
44
- mask = pred[0].squeeze()
45
- mask_pil = transforms.ToPILImage()(mask).resize(image_size).convert("L")
 
46
 
47
- # Create binary mask
48
  binary_mask = mask_pil.point(lambda p: 255 if p > 127 else 0)
49
 
50
  # Apply mask with white background
@@ -52,7 +57,6 @@ def process(image: Image.Image) -> Image.Image:
52
  result = Image.composite(image, white_bg, binary_mask)
53
  return result
54
 
55
- # Rest of your code remains the same...
56
  @spaces.GPU
57
  def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
58
  results = []
@@ -87,6 +91,9 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
87
 
88
  except Exception as e:
89
  print("General error:", e)
 
 
 
90
 
91
  return None
92
 
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
20
 
21
+ # Transform for RMBG v1.4
22
  transform_image = transforms.Compose([
23
  transforms.Resize((1024, 1024)),
24
  transforms.ToTensor(),
 
31
  input_tensor = transform_image(image).unsqueeze(0).to(device)
32
 
33
  with torch.no_grad():
34
+ # RMBG v1.4 returns a tuple, we need the first element
35
  preds = model(input_tensor)
36
+
37
+ # Handle different return types
38
+ if isinstance(preds, tuple):
39
+ pred = preds[0] # Take first element if tuple
40
+ elif isinstance(preds, list):
41
+ pred = preds[-1] # Take last element if list
42
  else:
43
  pred = preds
44
+
45
+ # Apply sigmoid and move to CPU
46
+ mask = pred.sigmoid().cpu()
47
 
48
+ # Process the mask
49
+ mask_tensor = mask[0].squeeze()
50
+ mask_pil = transforms.ToPILImage()(mask_tensor).resize(image_size).convert("L")
51
 
52
+ # Create binary mask with threshold
53
  binary_mask = mask_pil.point(lambda p: 255 if p > 127 else 0)
54
 
55
  # Apply mask with white background
 
57
  result = Image.composite(image, white_bg, binary_mask)
58
  return result
59
 
 
60
  @spaces.GPU
61
  def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
62
  results = []
 
91
 
92
  except Exception as e:
93
  print("General error:", e)
94
+ # Add debug info
95
+ import traceback
96
+ traceback.print_exc()
97
 
98
  return None
99