import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import cv2 from basicsr.utils import img2tensor, tensor2img _BATCH_NORM = nn.BatchNorm2d _BOTTLENECK_EXPANSION = 4 import blobfile as bf def _list_image_files_recursively(data_dir): results = [] for entry in sorted(bf.listdir(data_dir)): full_path = bf.join(data_dir, entry) ext = entry.split(".")[-1] if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: results.append(full_path) elif bf.isdir(full_path): results.extend(_list_image_files_recursively(full_path)) return results def uint82bin(n, count=8): """returns the binary of integer n, count refers to amount of bits""" return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) def labelcolormap(N): if N == 35: # cityscape cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81), (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153), (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)], dtype=np.uint8) else: cmap = np.zeros((N, 3), dtype=np.uint8) for i in range(N): r, g, b = 0, 0, 0 id = i + 1 # let's give 0 a color for j in range(7): str_id = uint82bin(id) r = r ^ (np.uint8(str_id[-1]) << (7 - j)) g = g ^ (np.uint8(str_id[-2]) << (7 - j)) b = b ^ (np.uint8(str_id[-3]) << (7 - j)) id = id >> 3 cmap[i, 0] = r cmap[i, 1] = g cmap[i, 2] = b return cmap class Colorize(object): def __init__(self, n=182): self.cmap = labelcolormap(n) def __call__(self, gray_image): size = gray_image.shape color_image = np.zeros((3, size[0], size[1])) for label in range(0, len(self.cmap)): mask = (label == gray_image ) color_image[0][mask] = self.cmap[label][0] color_image[1][mask] = self.cmap[label][1] color_image[2][mask] = self.cmap[label][2] return color_image class _ConvBnReLU(nn.Sequential): """ Cascade of 2D convolution, batch norm, and ReLU. """ BATCH_NORM = _BATCH_NORM def __init__( self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True ): super(_ConvBnReLU, self).__init__() self.add_module( "conv", nn.Conv2d( in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False ), ) self.add_module("bn", _BATCH_NORM(out_ch, eps=1e-5, momentum=1 - 0.999)) if relu: self.add_module("relu", nn.ReLU()) class _Bottleneck(nn.Module): """ Bottleneck block of MSRA ResNet. """ def __init__(self, in_ch, out_ch, stride, dilation, downsample): super(_Bottleneck, self).__init__() mid_ch = out_ch // _BOTTLENECK_EXPANSION self.reduce = _ConvBnReLU(in_ch, mid_ch, 1, stride, 0, 1, True) self.conv3x3 = _ConvBnReLU(mid_ch, mid_ch, 3, 1, dilation, dilation, True) self.increase = _ConvBnReLU(mid_ch, out_ch, 1, 1, 0, 1, False) self.shortcut = ( _ConvBnReLU(in_ch, out_ch, 1, stride, 0, 1, False) if downsample else nn.Identity() ) def forward(self, x): h = self.reduce(x) h = self.conv3x3(h) h = self.increase(h) h += self.shortcut(x) return F.relu(h) class _ResLayer(nn.Sequential): """ Residual layer with multi grids """ def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None): super(_ResLayer, self).__init__() if multi_grids is None: multi_grids = [1 for _ in range(n_layers)] else: assert n_layers == len(multi_grids) # Downsampling is only in the first block for i in range(n_layers): self.add_module( "block{}".format(i + 1), _Bottleneck( in_ch=(in_ch if i == 0 else out_ch), out_ch=out_ch, stride=(stride if i == 0 else 1), dilation=dilation * multi_grids[i], downsample=(True if i == 0 else False), ), ) class _Stem(nn.Sequential): """ The 1st conv layer. Note that the max pooling is different from both MSRA and FAIR ResNet. """ def __init__(self, out_ch): super(_Stem, self).__init__() self.add_module("conv1", _ConvBnReLU(3, out_ch, 7, 2, 3, 1)) self.add_module("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True)) class _ASPP(nn.Module): """ Atrous spatial pyramid pooling (ASPP) """ def __init__(self, in_ch, out_ch, rates): super(_ASPP, self).__init__() for i, rate in enumerate(rates): self.add_module( "c{}".format(i), nn.Conv2d(in_ch, out_ch, 3, 1, padding=rate, dilation=rate, bias=True), ) for m in self.children(): nn.init.normal_(m.weight, mean=0, std=0.01) nn.init.constant_(m.bias, 0) def forward(self, x): return sum([stage(x) for stage in self.children()]) class MSC(nn.Module): """ Multi-scale inputs """ def __init__(self, base, scales=None): super(MSC, self).__init__() self.base = base if scales: self.scales = scales else: self.scales = [0.5, 0.75] def forward(self, x): # Original logits = self.base(x) _, _, H, W = logits.shape interp = lambda l: F.interpolate( l, size=(H, W), mode="bilinear", align_corners=False ) # Scaled logits_pyramid = [] for p in self.scales: h = F.interpolate(x, scale_factor=p, mode="bilinear", align_corners=False) logits_pyramid.append(self.base(h)) # Pixel-wise max logits_all = [logits] + [interp(l) for l in logits_pyramid] logits_max = torch.max(torch.stack(logits_all), dim=0)[0] return logits_max class DeepLabV2(nn.Sequential): """ DeepLab v2: Dilated ResNet + ASPP Output stride is fixed at 8 """ def __init__(self, n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]): super(DeepLabV2, self).__init__() ch = [64 * 2 ** p for p in range(6)] self.add_module("layer1", _Stem(ch[0])) self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], 1, 1)) self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], 2, 1)) self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], 1, 2)) self.add_module("layer5", _ResLayer(n_blocks[3], ch[4], ch[5], 1, 4)) self.add_module("aspp", _ASPP(ch[5], n_classes, atrous_rates)) def freeze_bn(self): for m in self.modules(): if isinstance(m, _ConvBnReLU.BATCH_NORM): m.eval() def preprocessing(image, device): # Resize scale = 640 / max(image.shape[:2]) image = cv2.resize(image, dsize=None, fx=scale, fy=scale) raw_image = image.astype(np.uint8) # Subtract mean values image = image.astype(np.float32) image -= np.array( [ float(104.008), float(116.669), float(122.675), ] ) # Convert to torch.Tensor and add "batch" axis image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0) image = image.to(device) return image, raw_image # Model setup def seger(): model = MSC( base=DeepLabV2( n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24] ), scales=[0.5, 0.75], ) state_dict = torch.load('models/deeplabv2_resnet101_msc-cocostuff164k-100000.pth') model.load_state_dict(state_dict) # to skip ASPP return model if __name__ == '__main__': device = 'cuda' model = seger() model.to(device) model.eval() with torch.no_grad(): im = cv2.imread('/group/30042/chongmou/ft_local/Diffusion/baselines/SPADE/datasets/coco_stuff/val_img/000000000785.jpg', cv2.IMREAD_COLOR) im, raw_im = preprocessing(im, 'cuda') _, _, H, W = im.shape # Image -> Probability map logits = model(im) logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False) probs = F.softmax(logits, dim=1)[0] probs = probs.cpu().data.numpy() labelmap = np.argmax(probs, axis=0) print(labelmap.shape, np.max(labelmap), np.min(labelmap)) cv2.imwrite('mask.png', labelmap)