bthndmn12 commited on
Commit
08eeae0
1 Parent(s): 4fc2180

fixed some bugs

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -50,16 +50,23 @@ def greet(image):
50
 
51
  model.eval()
52
  with torch.no_grad():
53
- outputs = model(**inputs,multimask_outputs=False)
54
 
55
  seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(0))
56
  seg_prob = seg_prob.cpu().numpy().squeeze()
57
  seg_prob = (seg_prob > 0.5).astype(np.uint8)
58
 
59
- # Add an extra dimension to seg_prob
60
- seg_prob = np.expand_dims(seg_prob, axis=-1)
 
 
 
 
 
 
 
 
61
 
62
- return seg_prob
63
 
64
 
65
  iface = gr.Interface(fn= greet, inputs="image", outputs="image", title="Greeter")
 
50
 
51
  model.eval()
52
  with torch.no_grad():
53
+ outputs = model(**inputs, multimask_outputs=False)
54
 
55
  seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(0))
56
  seg_prob = seg_prob.cpu().numpy().squeeze()
57
  seg_prob = (seg_prob > 0.5).astype(np.uint8)
58
 
59
+ # Ensure the array is 2D (height, width) for grayscale image
60
+ if seg_prob.ndim > 2:
61
+ seg_prob = seg_prob.squeeze() # Remove extra dimensions if any
62
+ elif seg_prob.ndim < 2:
63
+ raise ValueError("Output mask has less than 2 dimensions")
64
+
65
+ # Convert the processed mask back to a PIL image
66
+ seg_prob_image = Image.fromarray(seg_prob)
67
+
68
+ return seg_prob_image
69
 
 
70
 
71
 
72
  iface = gr.Interface(fn= greet, inputs="image", outputs="image", title="Greeter")