import torch
import torchvision.transforms as transforms
import criteria.deeplab as deeplab
import PIL.Image as Image
import torch.nn as nn
import torch.nn.functional as F
from configs import paths_config, global_config
import numpy as np


class Mask(nn.Module):
    """Mask images with a face-region mask from a DeepLab face-parsing model.

    ``forward`` segments the *real* images and multiplies both the real and
    the generated images by the resulting mask.

    Segmentation classes (19 total):

    | Class      | Number | Class | Number |
    |------------|--------|-------|--------|
    | background | 0      | mouth | 10     |
    | skin       | 1      | u_lip | 11     |
    | nose       | 2      | l_lip | 12     |
    | eye_g      | 3      | hair  | 13     |
    | l_eye      | 4      | hat   | 14     |
    | r_eye      | 5      | ear_r | 15     |
    | l_brow     | 6      | neck_l| 16     |
    | r_brow     | 7      | neck  | 17     |
    | l_ear      | 8      | cloth | 18     |
    | r_ear      | 9      |       |        |

    Mask weights: 1.0 for every class in ``self.labels``, 0.1 for hair (13),
    0.0 for everything else (background, hat, cloth).
    """

    def __init__(self, device="cpu"):
        """Build the segmentation model and load its checkpoint.

        Args:
            device: passed to the deeplab constructor and used as
                ``map_location`` when loading the checkpoint. The model is then
                moved to ``global_config.device``, where inference runs.
        """
        super().__init__()
        self.seg_model = (
            deeplab.resnet101(
                path=paths_config.deeplab,
                pretrained=True,
                num_classes=19,
                num_groups=32,
                weight_std=True,
                beta=False,
                device=device,
            )
            .eval()
            .requires_grad_(False)
        )
        ckpt = torch.load(paths_config.deeplab, map_location=device)
        # Strip a 7-character key prefix (presumably "module." from a
        # DataParallel-wrapped checkpoint -- confirm) and drop the
        # "*tracked*" buffers before loading.
        state_dict = {
            k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
        }
        self.seg_model.load_state_dict(state_dict)
        self.seg_model = self.seg_model.to(global_config.device)
        # Classes kept at full weight: everything except background (0),
        # hair (13), hat (14) and cloth (18).
        self.labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17]
        # 25x25 all-ones kernel intended for optional mean-convolution
        # smoothing of the mask. Smoothing is currently disabled, so nothing
        # in this class uses it; kept because it is a public attribute.
        self.kernel = torch.ones((1, 1, 25, 25), device=global_config.device)

    def get_labels(self, img, size=(256, 256)):
        """Segment ``img`` and return a per-pixel class-index map.

        Args:
            img: image tensor fed to the segmentation model; it is resized to
                513x513 and normalized with ImageNet mean/std first.
            size: spatial size ``(H, W)`` of the returned label map. Defaults
                to ``(256, 256)``, the previously hard-coded size.

        Returns:
            Long tensor of class indices (argmax over the model's 19 output
            channels), resized with nearest-neighbour interpolation to
            ``size``. All singleton dimensions are squeezed, so a single image
            yields shape ``size`` and a batch of B > 1 yields ``(B, *size)``.
            The tensor stays on the model output's device.
        """
        # NOTE(review): the ImageNet mean/std below assume inputs scaled to
        # [0, 1] -- confirm callers pass images in that range.
        data_transforms = transforms.Compose(
            [
                transforms.Resize((513, 513)),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        img = data_transforms(img)
        with torch.no_grad():
            out = self.seg_model(img)
        _, label = torch.max(out, 1)
        # F.interpolate needs a floating-point 4-D tensor: the leading dim
        # turns a (B, H, W) label map into (1, B, H, W).
        label = label.unsqueeze(0).float()
        label = F.interpolate(label, size=size, mode="nearest").squeeze()
        # .long() keeps the labels on their current device; the former
        # .type(torch.LongTensor) silently copied them to the CPU.
        return label.long()

    def get_mask(self, label):
        """Convert a label map into a float weight mask of the same shape.

        Pixels whose class is in ``self.labels`` get 1.0, hair (13) gets 0.1,
        all other classes get 0.0. The mask is allocated on
        ``global_config.device``.
        """
        mask = torch.zeros_like(label, device=global_config.device, dtype=torch.float)
        for idx in self.labels:
            mask[label == idx] = 1
        # Hair keeps a small weight instead of being zeroed out.
        mask[label == 13] = 0.1
        return mask

    def forward(self, real_imgs, generated_imgs):
        """Mask both images with the face mask computed from ``real_imgs``.

        Args:
            real_imgs: images that are segmented and then masked.
            generated_imgs: images masked with the same mask; must have the
                same spatial size as ``real_imgs``.

        Returns:
            ``(real_imgs * mask, generated_imgs * mask)``.
        """
        label = self.get_labels(real_imgs, size=tuple(real_imgs.shape[-2:]))
        mask = self.get_mask(label)
        if real_imgs.dim() == 4:
            # get_labels squeezes singleton dims: (H, W) for one image,
            # (B, H, W) for a batch. Reshape to (B, 1, H, W) so the mask
            # broadcasts over channels for every batch size.
            mask = mask.reshape(real_imgs.shape[0], 1, *mask.shape[-2:])
        return real_imgs * mask, generated_imgs * mask