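# NOTE: the following lines are web-page residue from the file viewer, not part of the program.
# ac5113's picture
# added files
# 99a05f0
# raw
# history blame
# 3.63 kB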
# 2022.07.19 - Changed for CLIFF
# Huawei Technologies Co., Ltd.
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2019, University of Pennsylvania, Max Planck Institute for Intelligent Systems
# This program is free software; you can redistribute it and/or modify it
# under the terms of the MIT license.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.
# This script is borrowed and extended from SPIN
import torch
import torch.nn as nn
import numpy as np
import math
from common.imutils import rot6d_to_rotmat
from models.backbones.resnet import ResNet
class CLIFF(nn.Module):
    """SMPL Iterative Regressor with ResNet50 backbone.

    An image encoder produces a per-sample feature vector. That vector is
    concatenated with a 3-value bounding-box descriptor and the current
    pose / shape / camera estimates, and passed through two linear layers
    (each followed by dropout) to predict residual updates for the three
    estimates. The refinement is repeated ``n_iter`` times, starting from
    the mean parameters loaded at construction time unless the caller
    supplies its own initial values.
    """

    def __init__(self, smpl_mean_params, img_feat_num=2048):
        """Build the encoder, the regressor head and the mean-parameter buffers.

        Args:
            smpl_mean_params: Anything ``np.load`` accepts (typically a path)
                that yields the arrays ``'pose'``, ``'shape'`` and ``'cam'``.
            img_feat_num: Number of features per sample coming out of the
                encoder; it sizes the first linear layer, so it has to match
                what ``self.encoder`` really produces.
        """
        super().__init__()
        self.encoder = ResNet(layers=[3, 4, 6, 3])

        npose = 24 * 6  # 24 rotations, 6 values each (decoded by rot6d_to_rotmat)
        nshape = 10
        ncam = 3
        nbbox = 3

        fc1_feat_num = 1024
        fc2_feat_num = 1024
        final_feat_num = fc2_feat_num
        # Regressor input = image features + bbox info + current estimates.
        reg_in_feat_num = img_feat_num + nbbox + npose + nshape + ncam

        self.fc1 = nn.Linear(reg_in_feat_num, fc1_feat_num)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(fc1_feat_num, fc2_feat_num)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(final_feat_num, npose)
        self.decshape = nn.Linear(final_feat_num, nshape)
        self.deccam = nn.Linear(final_feat_num, ncam)

        # Tiny gain: the decoders start out predicting near-zero residuals,
        # so the first iterations stay close to the initial estimates.
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        # NOTE: self.modules() recurses into self.encoder as well, so any
        # Conv2d / BatchNorm2d found there is re-initialised here too.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        mean_params = np.load(smpl_mean_params)
        try:
            # Indexing the loaded object materialises each array in memory,
            # so the tensors remain valid after the archive is closed.
            # Everything is cast to float32 so the buffers share one dtype
            # with the layer weights they are concatenated with in forward().
            init_pose = torch.from_numpy(
                mean_params['pose'][:].astype('float32')).unsqueeze(0)
            init_shape = torch.from_numpy(
                mean_params['shape'][:].astype('float32')).unsqueeze(0)
            init_cam = torch.from_numpy(
                mean_params['cam'].astype('float32')).unsqueeze(0)
        finally:
            # For .npz archives np.load returns an NpzFile that keeps a file
            # handle open until closed; plain arrays have no close().
            close = getattr(mean_params, 'close', None)
            if close is not None:
                close()

        # Buffers follow the module across devices and are stored in the
        # state dict, but are not trainable parameters.
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x, bbox, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        """Iteratively regress pose, shape and camera for a batch.

        Args:
            x: Input batch for the encoder; its first dimension is the
                batch size.
            bbox: Per-sample bounding-box features, concatenated with the
                image features along dim 1 (3 values per sample, given how
                the first linear layer is sized).
            init_pose: Optional starting pose estimate; defaults to the
                stored mean pose expanded to the batch size.
            init_shape: Optional starting shape estimate; defaults to the
                stored mean shape expanded to the batch size.
            init_cam: Optional starting camera estimate; defaults to the
                stored mean camera expanded to the batch size.
            n_iter: Number of refinement iterations.

        Returns:
            Tuple ``(pred_rotmat, pred_shape, pred_cam)`` where
            ``pred_rotmat`` is viewed as ``(batch_size, 24, 3, 3)``.
        """
        batch_size = x.shape[0]

        # Fall back to the mean parameters for anything the caller omitted.
        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        xf = self.encoder(x)

        pred_pose, pred_shape, pred_cam = init_pose, init_shape, init_cam
        for _ in range(n_iter):
            xc = torch.cat([xf, bbox, pred_pose, pred_shape, pred_cam], 1)
            # No non-linearity between the linear layers, only dropout.
            xc = self.drop1(self.fc1(xc))
            xc = self.drop2(self.fc2(xc))
            # Each decoder predicts a residual on top of the current estimate.
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)
        return pred_rotmat, pred_shape, pred_cam