# Source history: echen01 — commit f7bf9fb ("fix deeplab device"), 4.13 kB
import torch
import torchvision.transforms as transforms
import criteria.deeplab as deeplab
import PIL.Image as Image
import torch.nn as nn
import torch.nn.functional as F
from configs import paths_config, global_config
import numpy as np
class Mask(nn.Module):
    """Compute soft face-region masks from images via a frozen DeepLab model.

    Segmentation class indices produced by the model:

    | Class      | Number | Class | Number |
    |------------|--------|-------|--------|
    | background | 0      | mouth | 10     |
    | skin       | 1      | u_lip | 11     |
    | nose       | 2      | l_lip | 12     |
    | eye_g      | 3      | hair  | 13     |
    | l_eye      | 4      | hat   | 14     |
    | r_eye      | 5      | ear_r | 15     |
    | l_brow     | 6      | neck_l| 16     |
    | r_brow     | 7      | neck  | 17     |
    | l_ear      | 8      | cloth | 18     |
    | r_ear      | 9      |
    """

    def __init__(self, device="cpu"):
        """Build and freeze the DeepLab segmenter.

        Args:
            device: map_location used when reading the checkpoint; the model
                itself ends up on ``global_config.device``.
        """
        super().__init__()
        # The segmenter is only used as a fixed label extractor: eval mode,
        # no gradients.
        self.seg_model = (
            getattr(deeplab, "resnet101")(
                path=paths_config.deeplab,
                pretrained=True,
                num_classes=19,
                num_groups=32,
                weight_std=True,
                beta=False,
                device=device,
            )
            .eval()
            .requires_grad_(False)
        )
        # Checkpoint keys carry a 7-char prefix (presumably "module." from
        # DataParallel — TODO confirm); strip it and drop the
        # "num_batches_tracked" buffers before loading.
        ckpt = torch.load(paths_config.deeplab, map_location=device)
        state_dict = {
            k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
        }
        self.seg_model.load_state_dict(state_dict)
        self.seg_model = self.seg_model.to(global_config.device)
        # Classes kept at full weight in the mask (face parts, ears, neck,
        # mouth region); background (0), hair (13), hat (14) and cloth (18)
        # are excluded here — hair gets a reduced weight in get_mask().
        self.labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17]
        # Mean-filter kernel for optional mask smoothing (currently unused
        # by get_mask, kept for experimentation).
        self.kernel = torch.ones((1, 1, 25, 25), device=global_config.device)
        # Preprocessing is identical for every call — build it once.
        self._transform = transforms.Compose(
            [
                transforms.Resize((513, 513)),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    def get_labels(self, img):
        """Return a (256, 256) per-pixel class-label map for ``img``.

        ``img`` is assumed to be an image batch tensor on the same device as
        the segmentation model — TODO confirm against callers.
        """
        img = self._transform(img)
        with torch.no_grad():
            out = self.seg_model(img)
        _, label = torch.max(out, 1)
        label = label.unsqueeze(0).type(torch.float32)
        # Nearest-neighbour keeps labels discrete while resizing to 256x256.
        # Use .long() rather than .type(torch.LongTensor): the latter silently
        # moves the tensor to CPU, which breaks the device-consistent indexing
        # in get_mask() when running on CUDA.
        label = (
            F.interpolate(label, size=(256, 256), mode="nearest").squeeze().long()
        )
        return label

    def get_mask(self, label):
        """Return a float mask: 1.0 on face-part pixels, 0.1 on hair, 0 elsewhere."""
        mask = torch.zeros_like(label, device=global_config.device, dtype=torch.float)
        for idx in self.labels:
            mask[label == idx] = 1
        # Hair (class 13) is down-weighted rather than fully masked out.
        mask[label == 13] = 0.1
        return mask

    def forward(self, real_imgs, generated_imgs):
        """Mask both image batches so downstream losses focus on face regions.

        The mask is computed from ``real_imgs`` only and applied to both.
        """
        label = self.get_labels(real_imgs)
        mask = self.get_mask(label)
        return real_imgs * mask, generated_imgs * mask