bthndmn12 commited on
Commit
4fc2180
1 Parent(s): fba9efa

fixed some bugs

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -52,14 +52,13 @@ def greet(image):
52
  with torch.no_grad():
53
  outputs = model(**inputs,multimask_outputs=False)
54
 
55
- # outputs = outputs.logits[0].cpu().numpy()
56
- # outputs = np.argmax(outputs, axis=0)
57
- # outputs = Image.fromarray(outputs)
58
- # return outputs
59
  seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(0))
60
  seg_prob = seg_prob.cpu().numpy().squeeze()
61
  seg_prob = (seg_prob > 0.5).astype(np.uint8)
62
 
 
 
 
63
  return seg_prob
64
 
65
 
 
52
  with torch.no_grad():
53
  outputs = model(**inputs,multimask_outputs=False)
54
 
 
 
 
 
55
  seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(0))
56
  seg_prob = seg_prob.cpu().numpy().squeeze()
57
  seg_prob = (seg_prob > 0.5).astype(np.uint8)
58
 
59
+ # Add an extra dimension to seg_prob
60
+ seg_prob = np.expand_dims(seg_prob, axis=-1)
61
+
62
  return seg_prob
63
 
64