# Source page metadata: Space status "Sleeping", commit 99a05f0, file size 3,631 bytes.
# 2022.07.19 - Changed for CLIFF
# Huawei Technologies Co., Ltd.
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2019, University of Pennsylvania, Max Planck Institute for Intelligent Systems
# This program is free software; you can redistribute it and/or modify it
# under the terms of the MIT license.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.
# This script is borrowed and extended from SPIN
import torch
import torch.nn as nn
import numpy as np
import math
from common.imutils import rot6d_to_rotmat
from models.backbones.resnet import ResNet
class CLIFF(nn.Module):
    """SMPL iterative regressor with a ResNet-50 backbone.

    The encoder turns each image into a feature vector. That vector is
    concatenated with 3 bounding-box features and the current SMPL estimate
    (pose, shape, camera), and the result goes through a two-layer MLP.
    Three linear heads then predict residual updates to pose, shape and
    camera. The update is applied ``n_iter`` times, starting from the mean
    parameters loaded from ``smpl_mean_params``.
    """

    def __init__(self, smpl_mean_params, img_feat_num=2048):
        """Build the encoder and regressor and load the mean SMPL parameters.

        Args:
            smpl_mean_params: path (or file-like object) accepted by
                ``np.load`` that refers to an ``.npz`` archive with
                ``'pose'``, ``'shape'`` and ``'cam'`` arrays. Their sizes
                must be 24 * 6, 10 and 3 respectively, because they are fed
                back into ``fc1`` alongside the image features.
            img_feat_num: length of the per-image feature vector produced by
                ``self.encoder``.
        """
        super().__init__()
        # [3, 4, 6, 3] residual blocks per stage is the ResNet-50 layout.
        self.encoder = ResNet(layers=[3, 4, 6, 3])

        npose = 24 * 6  # 24 joints x 6 rotation parameters each
        nshape = 10
        ncam = 3
        nbbox = 3

        fc1_feat_num = 1024
        fc2_feat_num = 1024
        final_feat_num = fc2_feat_num
        # Regressor input = image features + bbox features + current estimate.
        reg_in_feat_num = img_feat_num + nbbox + npose + nshape + ncam

        self.fc1 = nn.Linear(reg_in_feat_num, fc1_feat_num)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(fc1_feat_num, fc2_feat_num)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(final_feat_num, npose)
        self.decshape = nn.Linear(final_feat_num, nshape)
        self.deccam = nn.Linear(final_feat_num, ncam)
        # Small-gain init for the output heads' weights, presumably so the
        # early residual updates stay close to the starting estimate.
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        # He-normal (fan-out) init for every Conv2d, and weight=1 / bias=0
        # for every BatchNorm2d. self.modules() includes self.encoder, so
        # this also (re)initializes the backbone.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

        # The context manager closes the .npz archive's file handle. Each
        # array is read and copied by astype() before the archive is closed.
        # All three are cast to float32 so they concatenate cleanly with the
        # float32 image features in forward(). reshape(1, -1) yields a
        # (1, N) row whether the array is stored as (N,) or (1, N).
        with np.load(smpl_mean_params) as mean_params:
            init_pose, init_shape, init_cam = (
                torch.from_numpy(mean_params[key].astype('float32')).reshape(1, -1)
                for key in ('pose', 'shape', 'cam')
            )
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x, bbox, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        """Regress SMPL parameters for a batch of images.

        Args:
            x: image batch passed to ``self.encoder``; dim 0 is the batch
                size. The encoder must return a 2-D tensor of shape
                (batch_size, img_feat_num) for the concatenation below.
                NOTE(review): presumably (B, 3, H, W) images — confirm.
            bbox: bounding-box features of shape (batch_size, 3).
                NOTE(review): the meaning/normalization of the 3 values is
                defined by the caller's data pipeline — confirm there.
            init_pose, init_shape, init_cam: optional starting estimates of
                shape (batch_size, 144), (batch_size, 10) and (batch_size, 3).
                Each one left as None defaults to the stored mean parameters
                expanded to the batch.
            n_iter: number of residual-refinement iterations. With 0, the
                initial estimates are returned (pose converted to rotation
                matrices).

        Returns:
            Tuple ``(pred_rotmat, pred_shape, pred_cam)`` of shapes
            (batch_size, 24, 3, 3), (batch_size, 10) and (batch_size, 3).
        """
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        xf = self.encoder(x)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for _ in range(n_iter):
            # Feed the current estimate back in and predict a residual update.
            xc = torch.cat([xf, bbox, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        # Convert the 6-parameter-per-joint pose to 24 3x3 rotation matrices.
        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        return pred_rotmat, pred_shape, pred_cam
|