import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import PIL.Image as Image  # used only by the commented-out debug dumps in forward()

import criteria.deeplab as deeplab
from configs import paths_config, global_config
class Mask(nn.Module):
    def __init__(self, device="cpu"):
        """Face-region mask built from a pretrained DeepLab segmentation model.

        The 19 face-parsing classes predicted by the model:

        | Class      | Number | Class  | Number |
        |------------|--------|--------|--------|
        | background | 0      | mouth  | 10     |
        | skin       | 1      | u_lip  | 11     |
        | nose       | 2      | l_lip  | 12     |
        | eye_g      | 3      | hair   | 13     |
        | l_eye      | 4      | hat    | 14     |
        | r_eye      | 5      | ear_r  | 15     |
        | l_brow     | 6      | neck_l | 16     |
        | r_brow     | 7      | neck   | 17     |
        | l_ear      | 8      | cloth  | 18     |
        | r_ear      | 9      |        |        |
        """
        super().__init__()
        self.seg_model = (
            deeplab.resnet101(
                path=paths_config.deeplab,
                pretrained=True,
                num_classes=19,
                num_groups=32,
                weight_std=True,
                beta=False,
                device=device,
            )
            .eval()
            .requires_grad_(False)
        )
        ckpt = torch.load(paths_config.deeplab, map_location=device)
        # Strip the "module." prefix left by DataParallel and drop the
        # "num_batches_tracked" buffers before loading the weights.
        state_dict = {
            k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
        }
        self.seg_model.load_state_dict(state_dict)
        self.seg_model = self.seg_model.to(global_config.device)
        # Classes kept in the mask: everything except background (0),
        # hair (13), hat (14) and cloth (18); hair is re-added at a reduced
        # weight in get_mask().
        self.labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17]
        # 25x25 box kernel for the (currently disabled) mask smoothing below.
        self.kernel = torch.ones((1, 1, 25, 25), device=global_config.device)
    def get_labels(self, img):
        """Return a 256x256 per-pixel class-label map for an input image tensor."""
        data_transforms = transforms.Compose(
            [
                # 513x513 is the input resolution the DeepLab model expects.
                transforms.Resize((513, 513)),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        img = data_transforms(img)
        with torch.no_grad():
            out = self.seg_model(img)
        # Argmax over the 19 class channels gives the predicted label per pixel.
        _, label = torch.max(out, 1)
        label = label.unsqueeze(0).type(torch.float32)
        # Downsample to the working resolution; nearest-neighbour interpolation
        # keeps the labels discrete. .long() (rather than .type(torch.LongTensor),
        # which would silently move the tensor to CPU) preserves the device.
        label = (
            F.interpolate(label, size=(256, 256), mode="nearest")
            .squeeze()
            .long()
        )
        return label
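    # Shape walk-through, assuming a single (1, 3, 256, 256) input batch:
    # (1, 3, 513, 513) after Resize -> (1, 19, 513, 513) logits ->
    # (1, 513, 513) argmax labels -> (256, 256) long tensor after interpolation.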
    def get_mask(self, label):
        """Turn a label map into a soft mask that weights the face region."""
        mask = torch.zeros_like(label, device=global_config.device, dtype=torch.float)
        for idx in self.labels:
            mask[label == idx] = 1
        # Optional smoothing with the 25x25 box kernel, currently disabled.
        # Erosion variant (shrinks the mask):
        # mask = (
        #     1
        #     - torch.clamp(
        #         F.conv2d(1 - mask[None, None, :, :], self.kernel, padding="same"),
        #         0,
        #         1,
        #     ).squeeze()
        # )
        # Dilation variant (grows the mask):
        # mask = torch.clamp(
        #     F.conv2d(mask[None, None, :, :], self.kernel, padding="same"),
        #     0,
        #     1,
        # ).squeeze()
        # Hair is kept at a reduced weight so it contributes only weakly.
        mask[label == 13] = 0.1
        return mask
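    # A tiny worked example (hypothetical values) of the mapping above: for a
    # 2x2 label map [[1, 13], [0, 17]] (skin, hair, background, neck), get_mask
    # yields [[1.0, 0.1], [0.0, 1.0]] -- face parts at full weight, hair damped
    # to 0.1, background zeroed.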
    def forward(self, real_imgs, generated_imgs):
        """Mask both image batches so downstream losses focus on the face region."""
        label = self.get_labels(real_imgs)
        mask = self.get_mask(label)
        real_imgs = real_imgs * mask
        generated_imgs = generated_imgs * mask
        # Debug dumps (disabled): save the masked images and the mask itself.
        # out = real_imgs.squeeze().detach()
        # out = (out.permute(1, 2, 0) * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
        # Image.fromarray(out.cpu().numpy()).save("real_mask.png")
        # out = generated_imgs.squeeze().detach()
        # out = (out.permute(1, 2, 0) * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
        # Image.fromarray(out.cpu().numpy()).save("generated_mask.png")
        # mask_img = mask.squeeze().detach().repeat(3, 1, 1)
        # mask_img = (mask_img.permute(1, 2, 0) * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
        # Image.fromarray(mask_img.cpu().numpy()).save("mask.png")
        return real_imgs, generated_imgs
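
# Minimal usage sketch, assuming the DeepLab checkpoint at paths_config.deeplab
# is available and that `real` and `fake` (hypothetical names) are
# (1, 3, 256, 256) image tensors already on global_config.device:
#
#     masker = Mask(device=global_config.device)
#     masked_real, masked_fake = masker(real, fake)
#     loss = F.l1_loss(masked_real, masked_fake)  # e.g. a masked pixel loss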