# 2022.07.19 - Changed for CLIFF
# Huawei Technologies Co., Ltd.

# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2019, University of Pennsylvania, Max Planck Institute for Intelligent Systems

# This program is free software; you can redistribute it and/or modify it
# under the terms of the MIT license.

# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.

# This script is borrowed and extended from SPIN
import torch
import torch.nn as nn
import numpy as np
import math
import os.path as osp

from common.imutils import rot6d_to_rotmat
from models.backbones.hrnet.cls_hrnet import HighResolutionNet
from models.backbones.hrnet.hrnet_config import cfg
from models.backbones.hrnet.hrnet_config import update_config


class CLIFF(nn.Module):
    """SMPL iterative regressor with an HRNet-W48 backbone"""

    def __init__(self, smpl_mean_params, img_feat_num=2048):
        super(CLIFF, self).__init__()

        curr_dir = osp.dirname(osp.abspath(__file__))
        config_file = osp.join(curr_dir, "../backbones/hrnet/models/cls_hrnet_w48_sgd_lr5e-2_wd1e-4_bs32_x100.yaml")
        update_config(cfg, config_file)
        self.encoder = HighResolutionNet(cfg)
        npose = 24 * 6  # 24 SMPL joint rotations in the 6D representation
        nshape = 10     # SMPL shape (beta) coefficients
        ncam = 3        # weak-perspective camera parameters
        nbbox = 3       # bounding-box information

        fc1_feat_num = 1024
        fc2_feat_num = 1024
        final_feat_num = fc2_feat_num
        reg_in_feat_num = img_feat_num + nbbox + npose + nshape + ncam
        # A CUDA "illegal memory access" error occurs if MobileNet V3 is used
        # together with BatchNorm here, so no BN is used in the regressor.
        self.fc1 = nn.Linear(reg_in_feat_num, fc1_feat_num)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(fc1_feat_num, fc2_feat_num)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(final_feat_num, npose)
        self.decshape = nn.Linear(final_feat_num, nshape)
        self.deccam = nn.Linear(final_feat_num, ncam)
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x, bbox, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        xf = self.encoder(x)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        # Iterative error feedback: refine the estimates n_iter times, feeding the
        # current predictions back into the regressor together with the image features.
        for i in range(n_iter):
            xc = torch.cat([xf, bbox, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        return pred_rotmat, pred_shape, pred_cam
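

# Minimal usage sketch: builds the model and runs a forward pass on dummy tensors to
# illustrate the expected input/output shapes. The mean-params path "data/smpl_mean_params.npz",
# the 224x224 crop resolution, and the random bbox information are assumptions for illustration,
# not values prescribed by this file.
if __name__ == "__main__":
    model = CLIFF("data/smpl_mean_params.npz")  # assumed location of the SMPL mean parameters
    model.eval()

    images = torch.randn(2, 3, 224, 224)  # batch of cropped person images (assumed crop size)
    bbox_info = torch.randn(2, 3)         # per-crop bounding-box information, shape (B, nbbox)

    with torch.no_grad():
        pred_rotmat, pred_betas, pred_cam = model(images, bbox_info)

    print(pred_rotmat.shape)  # torch.Size([2, 24, 3, 3]) -- rotation matrices for the 24 SMPL joints
    print(pred_betas.shape)   # torch.Size([2, 10])       -- SMPL shape coefficients
    print(pred_cam.shape)     # torch.Size([2, 3])        -- weak-perspective camera parameters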