gavinyuan
add: PIPNet, arcface
b9be4e6
import os, cv2
import numpy as np
from PIL import Image, ImageFilter
import logging
import torch
import torch.nn as nn
import random
def get_label(data_name, label_file, task_type=None):
    """Read an annotation list from ``data/<data_name>/<label_file>``.

    Every non-blank line is split on whitespace. The first field is the
    image name; any remaining fields are parsed as floats.

    Args:
        data_name: dataset folder name under ``data``.
        label_file: annotation file name inside that folder.
        task_type: optional tag inserted between the image name and the
            target of every returned entry.

    Returns:
        * ``[]`` when the file has no non-blank lines.
        * The raw ``[[image_name], ...]`` token lists when the first
          non-blank line holds only an image name (no targets).
        * Otherwise a list of ``[image_name, target]`` entries, or
          ``[image_name, task_type, target]`` when ``task_type`` is given,
          where ``target`` is a 1-D float ``np.ndarray``.
    """
    label_path = os.path.join('data', data_name, label_file)
    with open(label_path, 'r') as f:
        # Blank lines (typically a trailing newline) carry no sample and
        # would otherwise fail on the field indexing below, so drop them.
        labels = [fields for fields in (line.strip().split() for line in f) if fields]

    # An empty annotation file yields an empty list instead of IndexError.
    if not labels:
        return []

    # Name-only files are returned as-is.
    if len(labels[0]) == 1:
        return labels

    labels_new = []
    for image_name, *raw_target in labels:
        target = np.array([float(x) for x in raw_target])
        if task_type is None:
            labels_new.append([image_name, target])
        else:
            labels_new.append([image_name, task_type, target])
    return labels_new
def get_meanface(meanface_file, num_nb):
    """Build neighbour lookup tables from a mean-face landmark file.

    The first line of ``meanface_file`` is parsed as whitespace-separated
    floats and reshaped into rows of two values, one row per landmark.

    Args:
        meanface_file: path of the text file holding the mean face.
        num_nb: number of nearest neighbours kept for every landmark.

    Returns:
        ``(meanface_indices, reverse_index1, reverse_index2, max_len)``

        * ``meanface_indices``: one array per landmark with the indices of
          its ``num_nb`` closest other landmarks (squared Euclidean
          distance, the landmark itself excluded).
        * ``reverse_index1`` / ``reverse_index2``: flat lists made of one
          segment per landmark. A segment lists the landmarks that picked
          this landmark as a neighbour (``reverse_index1``) and the slot it
          occupies in their neighbour list (``reverse_index2``).
        * ``max_len``: length of the longest segment before padding. Every
          non-empty segment is padded to this length by repeating its own
          entries cyclically.

    Note:
        A landmark that is nobody's neighbour has nothing to repeat and
        contributes an empty segment (unchanged from the previous
        behaviour). Callers that assume ``num_landmarks * max_len`` entries
        should verify this cannot occur for their data.
    """
    with open(meanface_file) as f:
        meanface = f.readlines()[0]

    meanface = meanface.strip().split()
    meanface = [float(x) for x in meanface]
    meanface = np.array(meanface).reshape(-1, 2)
    num_lms = meanface.shape[0]

    # each landmark predicts num_nb neighbors
    meanface_indices = []
    for i in range(num_lms):
        pt = meanface[i, :]
        dists = np.sum(np.power(pt - meanface, 2), axis=1)
        indices = np.argsort(dists)
        # position 0 is the landmark itself (distance 0), so skip it
        meanface_indices.append(indices[1:1 + num_nb])

    # each landmark predicted by X neighbors, X varies
    reversed_src = [[] for _ in range(num_lms)]
    reversed_slot = [[] for _ in range(num_lms)]
    for i in range(num_lms):
        for j in range(num_nb):
            target = meanface_indices[i][j]
            reversed_src[target].append(i)
            reversed_slot[target].append(j)

    max_len = max((len(src) for src in reversed_src), default=0)

    # tricks, make them have equal length for efficient computation.
    # Repeat each list as many times as needed to reach max_len; a fixed
    # repeat count would leave very short lists under-padded.
    for i in range(num_lms):
        cur_len = len(reversed_src[i])
        if cur_len == 0:
            continue
        repeats = -(-max_len // cur_len)  # ceil(max_len / cur_len)
        reversed_src[i] = (reversed_src[i] * repeats)[:max_len]
        reversed_slot[i] = (reversed_slot[i] * repeats)[:max_len]

    # make the indices 1-dim
    reverse_index1 = []
    reverse_index2 = []
    for i in range(num_lms):
        reverse_index1 += reversed_src[i]
        reverse_index2 += reversed_slot[i]
    return meanface_indices, reverse_index1, reverse_index2, max_len
def compute_loss_pip(outputs_map1, outputs_map2, outputs_map3, outputs_local_x, outputs_local_y, outputs_nb_x, outputs_nb_y, labels_map1, labels_map2, labels_map3, labels_local_x, labels_local_y, labels_nb_x, labels_nb_y, masks_map1, masks_map2, masks_map3, masks_local_x, masks_local_y, masks_nb_x, masks_nb_y, criterion_cls, criterion_reg, num_nb):
    """Compute the masked score-map loss and the four masked offset losses.

    ``outputs_map1`` must be 4-D; its first two sizes (batch, channel) drive
    all reshapes. Map/local tensors are flattened to
    ``(batch*channel, -1)`` and neighbour tensors to
    ``(batch*num_nb*channel, -1)``. Offsets are only evaluated at the
    arg-max position of each row of ``labels_map1``.

    Every loss is ``criterion(outputs*mask, labels*mask)`` divided by the sum
    of the mask, unless that sum is zero, in which case the raw criterion
    value is kept.

    Returns:
        ``(loss_map, loss_x, loss_y, loss_nb_x, loss_nb_y)``
    """
    tmp_batch, tmp_channel, _, _ = outputs_map1.size()
    rows = tmp_batch * tmp_channel
    rows_nb = tmp_batch * num_nb * tmp_channel

    # Position of the ground-truth peak in every flattened map-1 row.
    labels_map1 = labels_map1.view(rows, -1)
    peak_ids = torch.argmax(labels_map1, 1).view(-1, 1)
    # Each peak is repeated num_nb times to index the neighbour rows.
    peak_ids_nb = peak_ids.repeat(1, num_nb).view(-1, 1)

    def at_peak(tensor):
        # value of a per-landmark tensor at the peak position
        return torch.gather(tensor.view(rows, -1), 1, peak_ids)

    def at_peak_nb(tensor):
        # value of a per-neighbour tensor at the peak position
        return torch.gather(tensor.view(rows_nb, -1), 1, peak_ids_nb)

    def flatten_and_join(first, second, third):
        # flatten three maps per landmark row and join them column-wise
        return torch.cat([first.view(rows, -1), second.view(rows, -1), third.view(rows, -1)], 1)

    def masked_loss(criterion, outputs, labels, masks):
        loss = criterion(outputs * masks, labels * masks)
        if not masks.sum() == 0:
            loss /= masks.sum()
        return loss

    ##########################################
    loss_map = masked_loss(
        criterion_cls,
        flatten_and_join(outputs_map1, outputs_map2, outputs_map3),
        flatten_and_join(labels_map1, labels_map2, labels_map3),
        flatten_and_join(masks_map1, masks_map2, masks_map3),
    )

    ##########################################
    loss_x = masked_loss(criterion_reg, at_peak(outputs_local_x), at_peak(labels_local_x), at_peak(masks_local_x))
    loss_y = masked_loss(criterion_reg, at_peak(outputs_local_y), at_peak(labels_local_y), at_peak(masks_local_y))
    loss_nb_x = masked_loss(criterion_reg, at_peak_nb(outputs_nb_x), at_peak_nb(labels_nb_x), at_peak_nb(masks_nb_x))
    loss_nb_y = masked_loss(criterion_reg, at_peak_nb(outputs_nb_y), at_peak_nb(labels_nb_y), at_peak_nb(masks_nb_y))

    return loss_map, loss_x, loss_y, loss_nb_x, loss_nb_y
def train_model(det_head, net, train_loader, criterion_cls, criterion_reg, cls_loss_weight, reg_loss_weight, num_nb, optimizer, num_epochs, scheduler, save_dir, save_interval, device):
    """Train ``net`` for ``num_epochs`` epochs and periodically save its weights.

    Only ``det_head == 'pip'`` is supported; any other value prints a
    message and exits the process as soon as the first batch is reached.

    Args:
        det_head: detection head name; must be ``'pip'``.
        net: model returning the seven PIP outputs for a batch of inputs.
        train_loader: iterable of 15-element batches (inputs, 7 label
            tensors, 7 mask tensors); must support ``len()`` for logging.
        criterion_cls, criterion_reg: loss callables forwarded to
            ``compute_loss_pip``.
        cls_loss_weight, reg_loss_weight: weights of the map loss and of
            each of the four regression losses in the total loss.
        num_nb: number of neighbours per landmark.
        optimizer: optimizer stepped once per batch.
        num_epochs: number of epochs to run.
        scheduler: learning-rate scheduler stepped once per epoch.
        save_dir: directory receiving ``epoch<N>.pth`` state dicts.
        save_interval: a checkpoint is written whenever
            ``epoch % (save_interval - 1) == 0`` and ``epoch > 0``.
        device: device every batch tensor is moved to.

    Returns:
        The trained ``net``.
    """
    # The checkpoint period is save_interval - 1 (kept for compatibility).
    # Fall back to 1 when that would be zero so save_interval == 1 means
    # "every epoch" instead of raising ZeroDivisionError.
    save_period = (save_interval - 1) or 1
    batch_log_format = '[Epoch {:d}/{:d}, Batch {:d}/{:d}] <Total loss: {:.6f}> <map loss: {:.6f}> <x loss: {:.6f}> <y loss: {:.6f}> <nbx loss: {:.6f}> <nby loss: {:.6f}>'

    for epoch in range(num_epochs):
        epoch_banner = 'Epoch {}/{}'.format(epoch, num_epochs - 1)
        print(epoch_banner)
        logging.info(epoch_banner)
        print('-' * 10)
        logging.info('-' * 10)
        net.train()

        for i, data in enumerate(train_loader):
            if det_head != 'pip':
                # NOTE(review): exits with status 0 although this is an
                # error path; kept as-is for backward compatibility.
                print('No such head:', det_head)
                exit(0)

            (inputs,
             labels_map1, labels_map2, labels_map3,
             labels_x, labels_y, labels_nb_x, labels_nb_y,
             masks_map1, masks_map2, masks_map3,
             masks_x, masks_y, masks_nb_x, masks_nb_y) = [t.to(device) for t in data]

            outputs_map1, outputs_map2, outputs_map3, outputs_x, outputs_y, outputs_nb_x, outputs_nb_y = net(inputs)
            loss_map, loss_x, loss_y, loss_nb_x, loss_nb_y = compute_loss_pip(
                outputs_map1, outputs_map2, outputs_map3,
                outputs_x, outputs_y, outputs_nb_x, outputs_nb_y,
                labels_map1, labels_map2, labels_map3,
                labels_x, labels_y, labels_nb_x, labels_nb_y,
                masks_map1, masks_map2, masks_map3,
                masks_x, masks_y, masks_nb_x, masks_nb_y,
                criterion_cls, criterion_reg, num_nb)
            loss = cls_loss_weight*loss_map + reg_loss_weight*loss_x + reg_loss_weight*loss_y + reg_loss_weight*loss_nb_x + reg_loss_weight*loss_nb_y

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report the weighted loss terms every 10th batch.
            if i % 10 == 0:
                message = batch_log_format.format(
                    epoch, num_epochs-1, i, len(train_loader)-1, loss.item(),
                    cls_loss_weight*loss_map.item(),
                    reg_loss_weight*loss_x.item(),
                    reg_loss_weight*loss_y.item(),
                    reg_loss_weight*loss_nb_x.item(),
                    reg_loss_weight*loss_nb_y.item())
                print(message)
                logging.info(message)

        if epoch % save_period == 0 and epoch > 0:
            filename = os.path.join(save_dir, 'epoch%d.pth' % epoch)
            torch.save(net.state_dict(), filename)
            print(filename, 'saved')
        scheduler.step()
    return net
def forward_pip(net, inputs, preprocess, input_size, net_stride, num_nb):
    """Run ``net`` on a single sample and decode landmark coordinates.

    The network is switched to eval mode and evaluated without gradients.
    For every channel of the first score map the arg-max cell is located;
    the x/y offsets and the ``num_nb`` neighbour offsets read at that cell
    are added to the cell's column/row and divided by
    ``input_size / net_stride``.

    Args:
        net: model returning three score maps, x/y offset maps and x/y
            neighbour offset maps.
        inputs: batch fed to ``net``; the batch size must be 1.
        preprocess: unused by this function.
        input_size: network input resolution.
        net_stride: stride between the input and the first score map.
        num_nb: number of neighbours predicted per landmark.

    Returns:
        ``(tmp_x, tmp_y, tmp_nb_x, tmp_nb_y, [cls1, cls2, cls3], max_cls)``
        where ``tmp_x``/``tmp_y`` have one row per landmark,
        ``tmp_nb_x``/``tmp_nb_y`` have ``num_nb`` columns, ``cls1`` is the
        first score map flattened per channel, and ``max_cls`` is the peak
        score of every channel.
    """
    net.eval()
    with torch.no_grad():
        outputs_cls1, outputs_cls2, outputs_cls3, outputs_x, outputs_y, outputs_nb_x, outputs_nb_y = net(inputs)
        tmp_batch, tmp_channel, tmp_height, tmp_width = outputs_cls1.size()
        assert tmp_batch == 1

        num_rows = tmp_batch * tmp_channel
        num_rows_nb = tmp_batch * num_nb * tmp_channel
        grid_scale = 1.0 * input_size / net_stride

        # Peak cell (flat index) and peak score of every landmark channel.
        outputs_cls1 = outputs_cls1.view(num_rows, -1)
        max_ids = torch.argmax(outputs_cls1, 1)
        max_cls = torch.max(outputs_cls1, 1)[0]
        max_ids = max_ids.view(-1, 1)
        max_ids_nb = max_ids.repeat(1, num_nb).view(-1, 1)

        # Offsets predicted at the peak cell.
        offset_x = torch.gather(outputs_x.view(num_rows, -1), 1, max_ids)
        offset_y = torch.gather(outputs_y.view(num_rows, -1), 1, max_ids)
        offset_nb_x = torch.gather(outputs_nb_x.view(num_rows_nb, -1), 1, max_ids_nb).view(-1, num_nb)
        offset_nb_y = torch.gather(outputs_nb_y.view(num_rows_nb, -1), 1, max_ids_nb).view(-1, num_nb)

        # Column / row of the peak cell in the score map.
        peak_col = (max_ids % tmp_width).view(-1, 1).float()
        peak_row = (max_ids // tmp_width).view(-1, 1).float()

        tmp_x = (peak_col + offset_x.view(-1, 1)) / grid_scale
        tmp_y = (peak_row + offset_y.view(-1, 1)) / grid_scale
        tmp_nb_x = (peak_col + offset_nb_x).view(-1, num_nb) / grid_scale
        tmp_nb_y = (peak_row + offset_nb_y).view(-1, num_nb) / grid_scale

    return tmp_x, tmp_y, tmp_nb_x, tmp_nb_y, [outputs_cls1, outputs_cls2, outputs_cls3], max_cls
def compute_nme(lms_pred, lms_gt, norm):
    """Return the normalised mean error between two landmark sets.

    Both arrays are reshaped to rows of two values. The Euclidean distance
    of every predicted row to its ground-truth row is averaged and divided
    by ``norm``.

    Args:
        lms_pred: predicted landmarks as an array reshapeable to (-1, 2).
        lms_gt: ground-truth landmarks with the same number of values.
        norm: normalisation factor the mean distance is divided by.

    Returns:
        The mean point-to-point distance divided by ``norm``.
    """
    pred_points = lms_pred.reshape((-1, 2))
    gt_points = lms_gt.reshape((-1, 2))
    point_errors = np.linalg.norm(pred_points - gt_points, axis=1)
    return np.mean(point_errors) / norm