MWilinski commited on
Commit
755362b
·
1 Parent(s): 84bde76

feat: working demo

Browse files
Files changed (1) hide show
  1. app.py +24 -4
app.py CHANGED
@@ -3,13 +3,33 @@ import cv2
3
  from PIL import Image
4
  import numpy as np
5
  from gradio import components
 
 
 
 
 
 
 
 
6
 
 
 
7
 
8
- def segment_and_show(image):
9
- image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
10
 
11
- # TODO: Implement segmentation logic here
12
- return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  default_image = Image.open("demo.jpeg")
 
3
  from PIL import Image
4
  import numpy as np
5
  from gradio import components
6
+ import torchvision
7
+ from torchvision.models.detection import (
8
+ maskrcnn_resnet50_fpn,
9
+ MaskRCNN_ResNet50_FPN_Weights,
10
+ )
11
+ import torchvision.transforms.functional as F
12
+ import torch
13
+ from torchvision.utils import draw_segmentation_masks
14
 
15
+ weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
16
+ transforms = weights.transforms()
17
 
18
+ model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
19
+ model = model.eval()
20
 
21
+
22
+ def segment_and_show(image):
23
+ input_image = Image.fromarray(image)
24
+ input_tensor = torch.tensor(np.array(input_image))
25
+ input_tensor = input_tensor.permute(2, 0, 1)
26
+ input_image = transforms(input_image)
27
+ output = model([input_image])[0]
28
+ proba_threshold = 0.5
29
+ masks = output["masks"] > proba_threshold
30
+ masks = masks.squeeze(1)
31
+ image_with_segmasks = draw_segmentation_masks(input_tensor, masks, alpha=0.7)
32
+ return np.array(F.to_pil_image(image_with_segmasks))
33
 
34
 
35
  default_image = Image.open("demo.jpeg")