# echen01: add PTI (2e34814) | raw | history | blame | 4.1 kB
import torch
import torchvision.transforms as transforms
import criteria.deeplab as deeplab
import PIL.Image as Image
import torch.nn as nn
import torch.nn.functional as F
from configs import paths_config, global_config
import numpy as np
class Mask(nn.Module):
    """Mask real and generated images with a face-parsing segmentation.

    A DeepLab ResNet-101 model from ``criteria.deeplab`` segments
    ``real_imgs`` into 19 classes. Both image batches are then multiplied by a
    per-pixel weight derived from those labels (see ``get_mask``):

    * classes listed in ``self.labels`` (1-12 and 15-17 below) -> 1.0
    * hair (13) -> 0.1
    * every other class (0, 14, 18) -> 0.0

    | Class      | Number | Class  | Number |
    |------------|--------|--------|--------|
    | background | 0      | mouth  | 10     |
    | skin       | 1      | u_lip  | 11     |
    | nose       | 2      | l_lip  | 12     |
    | eye_g      | 3      | hair   | 13     |
    | l_eye      | 4      | hat    | 14     |
    | r_eye      | 5      | ear_r  | 15     |
    | l_brow     | 6      | neck_l | 16     |
    | r_brow     | 7      | neck   | 17     |
    | l_ear      | 8      | cloth  | 18     |
    | r_ear      | 9      |        |        |
    """

    def __init__(self):
        super().__init__()
        self.seg_model = (
            deeplab.resnet101(
                path=paths_config.deeplab,
                pretrained=True,
                num_classes=19,
                num_groups=32,
                weight_std=True,
                beta=False,
            )
            .eval()
            .requires_grad_(False)
        )
        # Load the checkpoint's weights explicitly. Keys containing "tracked"
        # are dropped (presumably BatchNorm ``num_batches_tracked`` counters),
        # and the first 7 characters of every key are stripped (presumably a
        # ``module.`` prefix left by DataParallel; verify against the file).
        ckpt = torch.load(paths_config.deeplab, map_location=global_config.device)
        state_dict = {
            k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
        }
        self.seg_model.load_state_dict(state_dict)
        self.seg_model = self.seg_model.to(global_config.device)

        # Preprocessing applied before segmentation, built once instead of on
        # every call. NOTE(review): these are ImageNet mean/std values, which
        # conventionally assume inputs in [0, 1]; confirm the range callers
        # actually pass in.
        self.data_transforms = transforms.Compose(
            [
                transforms.Resize((513, 513)),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # Class ids that receive full weight (1.0) in get_mask.
        self.labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17]
        # 25x25 box kernel. The methods of this class do not use it; it is
        # retained for compatibility.
        self.kernel = torch.ones((1, 1, 25, 25), device=global_config.device)

    def get_labels(self, img, size=(256, 256)):
        """Segment ``img`` and return per-pixel class ids.

        Args:
            img: image batch accepted by ``self.data_transforms`` and the
                segmentation model (presumably ``(B, 3, H, W)``; confirm).
            size: ``(height, width)`` of the returned label map. Defaults to
                ``(256, 256)``.

        Returns:
            ``torch.long`` tensor of class ids shaped ``(B, *size)``, with
            singleton dimensions squeezed (just ``size`` for a batch of one).
            The tensor stays on the device of the segmentation output.
        """
        img = self.data_transforms(img)
        with torch.no_grad():
            logits = self.seg_model(img)
        # Per-pixel class index, kept as (B, 1, H, W) so that interpolation
        # resizes only the spatial dims. Cast to float before interpolating.
        label = logits.argmax(dim=1, keepdim=True).float()
        label = F.interpolate(label, size=tuple(size), mode="nearest")
        # Use .long() rather than .type(torch.LongTensor): the latter always
        # copies the labels to the CPU.
        return label.squeeze().long()

    def get_mask(self, label):
        """Convert a label map into a per-pixel weight mask.

        Pixels whose class is in ``self.labels`` get 1.0, hair (class 13) gets
        0.1, and every other pixel gets 0.0.

        Args:
            label: integer class-id tensor, e.g. the output of ``get_labels``.

        Returns:
            float32 tensor with the shape of ``label``, on
            ``global_config.device``.
        """
        mask = torch.zeros_like(label, device=global_config.device, dtype=torch.float)
        for idx in self.labels:
            mask[label == idx] = 1
        # Hair is down-weighted rather than removed entirely.
        mask[label == 13] = 0.1
        return mask

    def forward(self, real_imgs, generated_imgs):
        """Apply the mask computed from ``real_imgs`` to both image batches.

        Args:
            real_imgs: image batch to segment, expected ``(B, C, H, W)``.
            generated_imgs: image batch with the same shape as ``real_imgs``.

        Returns:
            Tuple ``(real_imgs * mask, generated_imgs * mask)``.
        """
        # Resize the labels to the input resolution so the mask matches the
        # images for any H x W, not only 256 x 256.
        label = self.get_labels(real_imgs, size=tuple(real_imgs.shape[-2:]))
        mask = self.get_mask(label)
        # get_labels squeezes away the batch dim when B == 1. Restore an
        # explicit (B, 1, H, W) layout so the mask broadcasts over channels
        # for every batch size.
        mask = mask.reshape(-1, 1, *mask.shape[-2:])
        return real_imgs * mask, generated_imgs * mask