import cv2
import csv
import torch
import numpy as np

# Map ADE20K class names to zero-based class indices, parsed from the
# object150_info.csv shipped with the segmentation weights.
CLS_DICT = {}
with open('weights/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    for row in reader:
        # The Name column holds semicolon-separated synonyms; use the first one.
        name = row[5].split(";")[0]
        if name == 'screen':
            # Keep the generic 'screen' entry unambiguous by appending its second synonym.
            name = '_'.join(row[5].split(";")[:2])
        # Indices in the CSV are 1-based; the network output is 0-based.
        CLS_DICT[name] = int(row[0]) - 1

# Class names to be excluded by downstream filtering.
exclude = ['person', 'sky', 'car']


def read_deeplab_image(img, size):
    """Downscale an HxWx3 image so its longer side is at most `size`, then
    convert it to a 1x3xHxW float tensor with values in [0, 1]."""
    width, height = img.shape[1], img.shape[0]

    if max(width, height) > size:
        # cv2.resize expects the target size as (width, height).
        if width > height:
            img = cv2.resize(img, (size, int(size * height / width)), interpolation=cv2.INTER_AREA)
        else:
            img = cv2.resize(img, (int(size * width / height), size), interpolation=cv2.INTER_AREA)

    # HWC uint8 -> NCHW float in [0, 1].
    img = (torch.from_numpy(img.copy()).float() / 255).permute(2, 0, 1)[None]

    return img


def read_segmentation_image(img, size):
    """Prepare an image for the segmentation network: resize, scale to [0, 1],
    then normalize with the ImageNet channel mean and standard deviation."""
    img = read_deeplab_image(img, size=size)[0]
    img = img - torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    img = img / torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return img


def segment(rgb, size, device, segmentation_module):
    """Run semantic segmentation on an RGB image and return an HxW uint8 map
    of predicted class indices (matching the indices stored in CLS_DICT)."""
    img_data = read_segmentation_image(rgb, size=size)
    singleton_batch = {'img_data': img_data[None].to(device)}
    output_size = img_data.shape[1:]
    # Run the segmentation at the highest resolution.
    scores = segmentation_module(singleton_batch, segSize=output_size)
    # Take the highest-scoring class for each pixel.
    _, pred = torch.max(scores, dim=1)
    return pred.cpu()[0].numpy().astype(np.uint8)
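

# --- Usage sketch (added illustration, not part of the original file) ---
# A minimal example of how `segment` might be wired up, assuming the
# `mit_semseg` package from the MIT CSAIL semantic-segmentation repository is
# used to build `segmentation_module`. The architecture names and weight paths
# below are illustrative assumptions, not values taken from this project.
if __name__ == '__main__':
    from mit_semseg.models import ModelBuilder, SegmentationModule

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = ModelBuilder.build_encoder(
        arch='resnet50dilated', fc_dim=2048,
        weights='weights/encoder_epoch_20.pth')      # assumed weight path
    decoder = ModelBuilder.build_decoder(
        arch='ppm_deepsup', fc_dim=2048, num_class=150,
        weights='weights/decoder_epoch_20.pth',      # assumed weight path
        use_softmax=True)
    segmentation_module = SegmentationModule(encoder, decoder,
                                             torch.nn.NLLLoss(ignore_index=-1))
    segmentation_module.eval().to(device)

    bgr = cv2.imread('example.jpg')                  # hypothetical input image
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)       # OpenCV loads BGR; the model expects RGB
    with torch.no_grad():
        mask = segment(rgb, size=512, device=device,
                       segmentation_module=segmentation_module)
    print(mask.shape, 'person class index:', CLS_DICT.get('person'))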