# Realcat's picture
# update: major change
# 499e141
import cv2
import csv
import torch
import numpy as np
# Lookup from class name to class index, built from the class info CSV.
# Column 0 holds the class id (stored here minus one); column 5 holds a
# ';'-separated list of names for the class.
CLS_DICT = {}
with open('weights/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)  # skip the first row (presumably the header)
    for row in reader:
        parts = row[5].split(";")
        # A class whose first name is plain 'screen' is keyed by its first two
        # names joined with '_' instead -- presumably so it cannot collide
        # with another class that also starts with 'screen'.
        name = '_'.join(parts[:2]) if parts[0] == 'screen' else parts[0]
        CLS_DICT[name] = int(row[0]) - 1

# Class names singled out for exclusion; how the list is consumed is decided
# by code outside this section.
exclude = ['person', 'sky', 'car']
def read_deeplab_image(img, size):
    """Convert an H x W x C image array into a 1 x C x H x W float tensor.

    If the longer side of ``img`` exceeds ``size``, the image is first
    downscaled with area interpolation so that the longer side equals
    ``size`` and the aspect ratio is kept (the shorter side is truncated to
    an integer). Images that already fit are never upscaled.

    Args:
        img: NumPy image array indexed as (height, width, channel).
        size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        A float32 tensor of shape (1, C, H', W') holding the pixel values
        divided by 255 (assumes an 8-bit value range -- TODO confirm with
        callers).
    """
    height, width = img.shape[0], img.shape[1]
    if max(width, height) > size:
        # For very elongated images the truncated short side can become 0,
        # which cv2.resize rejects; keep it at least one pixel.
        if width > height:
            target = (size, max(1, int(size * height / width)))
        else:
            target = (max(1, int(size * width / height)), size)
        # cv2 expects the target size as (width, height).
        img = cv2.resize(img, target, interpolation=cv2.INTER_AREA)
    # copy() gives torch a fresh buffer, so the tensor never aliases the
    # caller's array.
    tensor = torch.from_numpy(img.copy()).float() / 255
    # HWC -> CHW, then add a leading batch axis.
    return tensor.permute(2, 0, 1)[None]
def read_segmentation_image(img, size):
    """Prepare an image array as a normalised C x H x W tensor.

    The image is converted (and, if needed, downscaled) by
    ``read_deeplab_image``; the batch axis is dropped and each channel is
    shifted and scaled by fixed per-channel constants (presumably the usual
    ImageNet mean / standard deviation -- confirm against the model in use).

    Args:
        img: NumPy image array indexed as (height, width, channel); the
            three constants per tensor imply three channels.
        size: Maximum allowed length of the longer image side.

    Returns:
        The normalised float tensor of shape (C, H', W').
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    chw = read_deeplab_image(img, size=size)[0]
    return (chw - mean) / std
def segment(rgb, size, device, segmentation_module):
    """Run a segmentation module on one image and return its label map.

    Args:
        rgb: NumPy image array indexed as (height, width, channel).
        size: Maximum allowed length of the longer image side; larger images
            are downscaled before inference.
        device: Torch device the input batch is moved to.
        segmentation_module: Callable invoked as
            ``segmentation_module({'img_data': batch}, segSize=(H, W))`` that
            returns per-class scores with the class axis at dimension 1
            (presumably shaped (1, num_classes, H, W) -- confirm against the
            model used).

    Returns:
        A ``np.uint8`` array with the highest-scoring class index for each
        position of the first (only) batch element. Indices are cast to
        uint8, so class ids above 255 would wrap around.
    """
    img_data = read_segmentation_image(rgb, size=size)
    singleton_batch = {'img_data': img_data[None].to(device)}
    # Request scores at the resolution of the pre-processed input.
    output_size = img_data.shape[1:]
    # Only a NumPy label map leaves this function, so gradient tracking
    # during the forward pass would just waste memory.
    with torch.no_grad():
        scores = segmentation_module(singleton_batch, segSize=output_size)
        # Index of the best-scoring class at every pixel.
        _, pred = torch.max(scores, dim=1)
    return pred.cpu()[0].numpy().astype(np.uint8)