# Source path: HaMeR / mmpose / models / detectors / multiview_pose.py
# Web-viewer residue (not part of the module): geopavlakos's picture,
# "Initial commit", d7a991a, raw, history blame, 34.1 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import load_checkpoint
from mmpose.core.camera import SimpleCameraTorch
from mmpose.core.post_processing.post_transforms import (
affine_transform_torch, get_affine_transform)
from .. import builder
from ..builder import POSENETS
from .base import BasePose
class ProjectLayer(nn.Module):
    """Project layer to get voxel feature.

    Adapted from
    https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models/project_layer.py.

    Args:
        image_size (int or list): input size of the 2D model
        heatmap_size (int or list): output size of the 2D model
    """

    def __init__(self, image_size, heatmap_size):
        super(ProjectLayer, self).__init__()
        # A scalar size is shorthand for a square ``[size, size]``.
        if isinstance(image_size, int):
            image_size = [image_size, image_size]
        if isinstance(heatmap_size, int):
            heatmap_size = [heatmap_size, heatmap_size]
        self.image_size = image_size
        self.heatmap_size = heatmap_size

    def compute_grid(self, box_size, box_center, num_bins, device=None):
        """Build the coordinates of a regular 3D grid around a center.

        Args:
            box_size (int | float | list): Extent of the box along x, y, z.
                A scalar is used for all three axes.
            box_center (sequence): Center of the box; only the first three
                entries are read.
            num_bins (int | list): Number of grid points along x, y, z.
                A scalar is used for all three axes.
            device (torch.device, optional): Device of the returned tensor.

        Returns:
            torch.Tensor: Grid coordinates of shape
            ``(num_bins[0] * num_bins[1] * num_bins[2], 3)``, ordered with x
            varying slowest and z fastest.
        """
        if isinstance(box_size, (int, float)):
            box_size = [box_size, box_size, box_size]
        if isinstance(num_bins, int):
            num_bins = [num_bins, num_bins, num_bins]

        axes = [
            torch.linspace(
                -box_size[axis] / 2,
                box_size[axis] / 2,
                num_bins[axis],
                device=device) + box_center[axis] for axis in range(3)
        ]
        # Same rows as flattening an 'ij'-indexed meshgrid and concatenating
        # the three coordinate columns, without relying on the default
        # ``indexing`` of ``torch.meshgrid``.
        return torch.cartesian_prod(*axes)

    def get_voxel(self, feature_maps, meta, grid_size, grid_center, cube_size):
        """Sample the multi-view feature maps on a 3D grid.

        For every sample and view, the grid points are projected to pixels
        with the view's camera, mapped to feature-map coordinates, and the
        feature map is sampled there. Views are then averaged, counting only
        the views in which a grid point projects inside the image.

        Args:
            feature_maps (list[torch.Tensor]): One ``N x C x H x W`` tensor
                per view.
            meta (list[dict]): Per-sample metadata; the keys ``'center'``,
                ``'scale'`` and ``'camera'`` are indexed by view.
            grid_size (int | float | list): Extent of the 3D grid.
            grid_center (sequence): Either a single shared center or one
                center per sample. When a center has more than three entries,
                the sample is skipped if its fourth entry is negative.
            cube_size (list): Number of grid points along each axis.

        Returns:
            tuple: ``(cubes, grids)`` where ``cubes`` has shape
            ``(N, C, cube_size[0], cube_size[1], cube_size[2])`` with values
            in ``[0, 1]`` and ``grids`` has shape ``(N, num_bins, 3)``.
            Skipped samples are left as zeros in both.
        """
        device = feature_maps[0].device
        batch_size, num_channels = feature_maps[0].shape[:2]
        num_bins = cube_size[0] * cube_size[1] * cube_size[2]
        num_views = len(feature_maps)
        w, h = self.heatmap_size

        cubes = torch.zeros(
            batch_size, num_channels, 1, num_bins, num_views, device=device)
        grids = torch.zeros(batch_size, num_bins, 3, device=device)
        # 1 where a grid point projects inside the image of a view, else 0.
        bounding = torch.zeros(
            batch_size, 1, 1, num_bins, num_views, device=device)

        # Loop-invariant scale factors, built once instead of per view.
        heatmap_wh = torch.tensor([w, h], dtype=torch.float, device=device)
        image_wh = torch.tensor(
            self.image_size, dtype=torch.float, device=device)
        heatmap_extent = torch.tensor([w - 1, h - 1],
                                      dtype=torch.float,
                                      device=device)
        shared_center = len(grid_center) == 1

        for i in range(batch_size):
            if not (len(grid_center[0]) == 3 or grid_center[i][3] >= 0):
                continue

            center_i = grid_center[0] if shared_center else grid_center[i]
            grid = self.compute_grid(
                grid_size, center_i, cube_size, device=device)
            grids[i:i + 1] = grid

            for c in range(num_views):
                center = meta[i]['center'][c]
                scale = meta[i]['scale'][c]
                # The image extent is taken to be twice the stored center.
                width, height = center * 2
                trans = torch.as_tensor(
                    get_affine_transform(center, scale / 200.0, 0,
                                         self.image_size),
                    dtype=torch.float,
                    device=device)

                cam_param = meta[i]['camera'][c].copy()
                single_view_camera = SimpleCameraTorch(
                    param=cam_param, device=device)
                xy = single_view_camera.world_to_pixel(grid)

                bounding[i, 0, 0, :, c] = (xy[:, 0] >= 0) & (
                    xy[:, 1] >= 0) & (xy[:, 0] < width) & (
                        xy[:, 1] < height)

                xy = torch.clamp(xy, -1.0, max(width, height))
                xy = affine_transform_torch(xy, trans)
                # Network-input pixels -> feature-map pixels.
                xy = xy * heatmap_wh / image_wh
                # Feature-map pixels -> the [-1, 1] range of grid_sample.
                sample_grid = xy / heatmap_extent * 2.0 - 1.0
                sample_grid = torch.clamp(
                    sample_grid.view(1, 1, num_bins, 2), -1.1, 1.1)

                cubes[i:i + 1, :, :, :, c] += F.grid_sample(
                    feature_maps[c][i:i + 1, :, :, :],
                    sample_grid,
                    align_corners=True)

        # Average over the views that actually see each grid point.
        cubes = torch.sum(
            torch.mul(cubes, bounding), dim=-1) / (
                torch.sum(bounding, dim=-1) + 1e-6)
        cubes[torch.isnan(cubes)] = 0.0
        cubes = cubes.clamp(0.0, 1.0)
        cubes = cubes.view(batch_size, num_channels, cube_size[0],
                           cube_size[1], cube_size[2])
        return cubes, grids

    def forward(self, feature_maps, meta, grid_size, grid_center, cube_size):
        """Return ``(cubes, grids)`` from :meth:`get_voxel`."""
        cubes, grids = self.get_voxel(feature_maps, meta, grid_size,
                                      grid_center, cube_size)
        return cubes, grids
@POSENETS.register_module()
class DetectAndRegress(BasePose):
    """DetectAndRegress approach for multiview human pose detection.

    Args:
        backbone (ConfigDict): Dictionary to construct the 2D pose detector
        human_detector (ConfigDict): dictionary to construct human detector
        pose_regressor (ConfigDict): dictionary to construct pose regressor
        train_cfg (ConfigDict): Config for training. Default: None.
        test_cfg (ConfigDict): Config for testing. Default: None.
        pretrained (str): Path to the pretrained 2D model. Default: None.
        freeze_2d (bool): Whether to freeze the 2D model in training.
            Default: True.
    """

    def __init__(self,
                 backbone,
                 human_detector,
                 pose_regressor,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 freeze_2d=True):
        super(DetectAndRegress, self).__init__()
        if backbone is not None:
            self.backbone = builder.build_posenet(backbone)
            if self.training and pretrained is not None:
                load_checkpoint(self.backbone, pretrained)
        else:
            # Without a 2D model, heatmaps must be supplied by the caller.
            self.backbone = None

        self.freeze_2d = freeze_2d
        self.human_detector = builder.MODELS.build(human_detector)
        self.pose_regressor = builder.MODELS.build(pose_regressor)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    @staticmethod
    def _freeze(model):
        """Freeze parameters."""
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Sets the module in training mode.

        Args:
            mode (bool): whether to set training mode (``True``)
                or evaluation mode (``False``). Default: ``True``.

        Returns:
            Module: self
        """
        super().train(mode)
        # Re-freeze after every switch to training mode, because the call
        # above puts the 2D model back into training mode as well.
        if mode and self.freeze_2d and self.backbone is not None:
            self._freeze(self.backbone)

        return self

    def _extract_feature_maps(self, img, input_heatmaps=None):
        """Return one 2D feature map per camera view.

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model. Only read when
                a 2D model is available.
            input_heatmaps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps when the 2D model is not available.
                Default: None.

        Returns:
            list: The first entry of each view's heatmaps, taken either from
            ``input_heatmaps`` or from ``self.backbone.forward_dummy``.
        """
        if self.backbone is None:
            assert input_heatmaps is not None
            return [input_heatmap[0] for input_heatmap in input_heatmaps]

        assert isinstance(img, list)
        return [self.backbone.forward_dummy(img_)[0] for img_ in img]

    def forward(self,
                img=None,
                img_metas=None,
                return_loss=True,
                targets=None,
                masks=None,
                targets_3d=None,
                input_heatmaps=None,
                **kwargs):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            return_loss: Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.
            targets (list(torch.Tensor[NxKxHxW])):
                Multi-camera target feature_maps of the 2D model.
            masks (list(torch.Tensor[NxHxW])):
                Multi-camera masks of the input to the 2D model.
            targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
                Ground-truth 3D heatmap of human centers.
            input_heatmaps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps when the 2D model is not available.
                Default: None.
            **kwargs: Accepted and ignored.

        Returns:
            dict: if 'return_loss' is true, then return losses.
                Otherwise, return predicted poses, human centers and sample_id
        """
        if return_loss:
            return self.forward_train(img, img_metas, targets, masks,
                                      targets_3d, input_heatmaps)
        else:
            return self.forward_test(img, img_metas, input_heatmaps)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_batch (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self.forward(**data_batch)

        loss, log_vars = self._parse_losses(losses)
        if 'img' in data_batch:
            batch_size = data_batch['img'][0].shape[0]
        else:
            assert 'input_heatmaps' in data_batch
            batch_size = data_batch['input_heatmaps'][0][0].shape[0]

        outputs = dict(loss=loss, log_vars=log_vars, num_samples=batch_size)

        return outputs

    def forward_train(self,
                      img,
                      img_metas,
                      targets=None,
                      masks=None,
                      targets_3d=None,
                      input_heatmaps=None):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            targets (list(torch.Tensor[NxKxHxW])):
                Multi-camera target feature_maps of the 2D model.
            masks (list(torch.Tensor[NxHxW])):
                Multi-camera masks of the input to the 2D model.
            targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
                Ground-truth 3D heatmap of human centers.
            input_heatmaps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps when the 2D model is not available.
                Default: None.

        Returns:
            dict: losses.
        """
        feature_maps = self._extract_feature_maps(img, input_heatmaps)

        losses = dict()
        human_candidates, human_loss = self.human_detector.forward_train(
            None, img_metas, feature_maps, targets_3d, return_preds=True)
        losses.update(human_loss)

        pose_loss = self.pose_regressor(
            None,
            img_metas,
            return_loss=True,
            feature_maps=feature_maps,
            human_candidates=human_candidates)
        losses.update(pose_loss)

        # The 2D loss needs the 2D model; when only precomputed heatmaps are
        # given there is nothing to supervise, so the term is skipped.
        if not self.freeze_2d and self.backbone is not None:
            losses_2d = {}
            heatmaps_tensor = torch.cat(feature_maps, dim=0)
            targets_tensor = torch.cat(targets, dim=0)
            masks_tensor = torch.cat(masks, dim=0)
            losses_2d_ = self.backbone.get_loss(heatmaps_tensor,
                                                targets_tensor, masks_tensor)
            # Suffix the keys so they cannot collide with the 3D losses.
            for k, v in losses_2d_.items():
                losses_2d[k + '_2d'] = v
            losses.update(losses_2d)

        return losses

    def forward_test(
        self,
        img,
        img_metas,
        input_heatmaps=None,
    ):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            input_heatmaps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps when the 2D model is not available.
                Default: None.

        Returns:
            dict: predicted poses, human centers and sample_id
        """
        feature_maps = self._extract_feature_maps(img, input_heatmaps)

        human_candidates = self.human_detector.forward_test(
            None, img_metas, feature_maps)

        human_poses = self.pose_regressor(
            None,
            img_metas,
            return_loss=False,
            feature_maps=feature_maps,
            human_candidates=human_candidates)

        result = {}
        result['pose_3d'] = human_poses.cpu().numpy()
        result['human_detection_3d'] = human_candidates.cpu().numpy()
        result['sample_id'] = [img_meta['sample_id'] for img_meta in img_metas]

        return result

    def show_result(self, **kwargs):
        """Visualize the results."""
        raise NotImplementedError

    def forward_dummy(self, img, input_heatmaps=None, num_candidates=5):
        """Used for computing network FLOPs."""
        feature_maps = self._extract_feature_maps(img, input_heatmaps)

        _ = self.human_detector.forward_dummy(feature_maps)
        _ = self.pose_regressor.forward_dummy(feature_maps, num_candidates)
@POSENETS.register_module()
class VoxelSinglePose(BasePose):
    """VoxelPose Please refer to the `paper <https://arxiv.org/abs/2004.06239>`
    for details.

    Args:
        image_size (list): input size of the 2D model.
        heatmap_size (list): output size of the 2D model.
        sub_space_size (list): Size of the cuboid human proposal.
        sub_cube_size (list): Size of the input volume to the pose net.
        num_joints (int): Number of joints predicted for each candidate.
        pose_net (ConfigDict): Dictionary to construct the pose net.
        pose_head (ConfigDict): Dictionary to construct the pose head.
        train_cfg (ConfigDict): Config for training. Default: None.
        test_cfg (ConfigDict): Config for testing. Default: None.
    """

    def __init__(
        self,
        image_size,
        heatmap_size,
        sub_space_size,
        sub_cube_size,
        num_joints,
        pose_net,
        pose_head,
        train_cfg=None,
        test_cfg=None,
    ):
        super(VoxelSinglePose, self).__init__()
        self.project_layer = ProjectLayer(image_size, heatmap_size)
        self.pose_net = builder.build_backbone(pose_net)
        self.pose_head = builder.build_head(pose_head)

        self.sub_space_size = sub_space_size
        self.sub_cube_size = sub_cube_size

        self.num_joints = num_joints
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def forward(self,
                img,
                img_metas,
                return_loss=True,
                feature_maps=None,
                human_candidates=None,
                **kwargs):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            feature_maps (list(torch.Tensor[NxCxHxW])):
                Multi-camera input feature_maps.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            human_candidates (torch.Tensor[NxPx5]):
                Human candidates.
            return_loss: Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.

        Returns:
            The losses from ``forward_train`` if ``return_loss`` is true,
            otherwise the predictions from ``forward_test``.
        """
        if return_loss:
            return self.forward_train(img, img_metas, feature_maps,
                                      human_candidates)
        else:
            return self.forward_test(img, img_metas, feature_maps,
                                     human_candidates)

    def forward_train(self,
                      img,
                      img_metas,
                      feature_maps=None,
                      human_candidates=None,
                      return_preds=False,
                      **kwargs):
        """Defines the computation performed at training.

        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            feature_maps (list(torch.Tensor[NxCxHxW])):
                Multi-camera input feature_maps.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            human_candidates (torch.Tensor[NxPx5]):
                Human candidates. A candidate is used only when its channel 3
                is non-negative; that value is then read as the index of the
                matched ground-truth person.
            return_preds (bool): Whether to return prediction results

        Returns:
            dict: losses. When ``return_preds`` is true, the tuple
            ``(pred, losses)`` is returned instead.
        """
        batch_size, num_candidates, _ = human_candidates.shape
        pred = human_candidates.new_zeros(batch_size, num_candidates,
                                          self.num_joints, 5)
        # Copy the last two candidate channels to every joint.
        pred[:, :, :, 3:] = human_candidates[:, :, None, 3:]

        device = feature_maps[0].device
        gt_3d = torch.stack([
            torch.tensor(img_meta['joints_3d'], device=device)
            for img_meta in img_metas
        ])
        gt_3d_vis = torch.stack([
            torch.tensor(img_meta['joints_3d_visible'], device=device)
            for img_meta in img_metas
        ])

        valid_preds = []
        valid_targets = []
        valid_weights = []

        for n in range(num_candidates):
            index = pred[:, n, 0, 3] >= 0
            num_valid = index.sum()
            if num_valid > 0:
                pose_input_cube, coordinates \
                    = self.project_layer(feature_maps,
                                         img_metas,
                                         self.sub_space_size,
                                         human_candidates[:, n, :3],
                                         self.sub_cube_size)
                pose_heatmaps_3d = self.pose_net(pose_input_cube)
                pose_3d = self.pose_head(pose_heatmaps_3d[index],
                                         coordinates[index])

                # Stored predictions are detached; gradients flow only
                # through the copies collected in ``valid_preds``.
                pred[index, n, :, 0:3] = pose_3d.detach()
                gt_ids = pred[index, n, 0, 3].long()
                valid_targets.append(gt_3d[index, gt_ids])
                valid_weights.append(gt_3d_vis[index, gt_ids, :,
                                               0:1].float())
                valid_preds.append(pose_3d)

        losses = dict()
        if len(valid_preds) > 0:
            valid_targets = torch.cat(valid_targets, dim=0)
            valid_weights = torch.cat(valid_weights, dim=0)
            valid_preds = torch.cat(valid_preds, dim=0)
            losses.update(
                self.pose_head.get_loss(valid_preds, valid_targets,
                                        valid_weights))
        else:
            # No valid candidate in the batch: run the networks on an
            # all-zero input with all-zero weights so that the loss dict
            # is still produced.
            pose_input_cube = feature_maps[0].new_zeros(
                batch_size, self.num_joints, *self.sub_cube_size)
            coordinates = feature_maps[0].new_zeros(batch_size,
                                                    *self.sub_cube_size,
                                                    3).view(batch_size, -1, 3)
            pseudo_targets = feature_maps[0].new_zeros(batch_size,
                                                       self.num_joints, 3)
            pseudo_weights = feature_maps[0].new_zeros(batch_size,
                                                       self.num_joints, 1)
            pose_heatmaps_3d = self.pose_net(pose_input_cube)
            pose_3d = self.pose_head(pose_heatmaps_3d, coordinates)
            losses.update(
                self.pose_head.get_loss(pose_3d, pseudo_targets,
                                        pseudo_weights))

        if return_preds:
            return pred, losses
        else:
            return losses

    def forward_test(self,
                     img,
                     img_metas,
                     feature_maps=None,
                     human_candidates=None,
                     **kwargs):
        """Defines the computation performed at testing.

        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            feature_maps width: W
            feature_maps height: H
            volume_length: cubeL
            volume_width: cubeW
            volume_height: cubeH

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            feature_maps (list(torch.Tensor[NxCxHxW])):
                Multi-camera input feature_maps.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            human_candidates (torch.Tensor[NxPx5]):
                Human candidates. A candidate is used only when its channel 3
                is non-negative.

        Returns:
            torch.Tensor[NxPxKx5]: For every candidate and joint, the
            predicted 3D position in channels 0-2 and the candidate's last
            two channels in channels 3-4. Positions of unused candidates
            stay zero.
        """
        batch_size, num_candidates, _ = human_candidates.shape
        pred = human_candidates.new_zeros(batch_size, num_candidates,
                                          self.num_joints, 5)
        pred[:, :, :, 3:] = human_candidates[:, :, None, 3:]

        for n in range(num_candidates):
            index = pred[:, n, 0, 3] >= 0
            num_valid = index.sum()
            if num_valid > 0:
                pose_input_cube, coordinates \
                    = self.project_layer(feature_maps,
                                         img_metas,
                                         self.sub_space_size,
                                         human_candidates[:, n, :3],
                                         self.sub_cube_size)
                pose_heatmaps_3d = self.pose_net(pose_input_cube)
                pose_3d = self.pose_head(pose_heatmaps_3d[index],
                                         coordinates[index])

                pred[index, n, :, 0:3] = pose_3d.detach()

        return pred

    def show_result(self, **kwargs):
        """Visualize the results."""
        raise NotImplementedError

    def forward_dummy(self, feature_maps, num_candidates=5):
        """Used for computing network FLOPs."""
        # Feature maps are N x C x H x W; only the first two sizes are needed.
        batch_size, num_channels = feature_maps[0].shape[:2]
        pose_input_cube = feature_maps[0].new_zeros(batch_size, num_channels,
                                                    *self.sub_cube_size)
        for n in range(num_candidates):
            _ = self.pose_net(pose_input_cube)
@POSENETS.register_module()
class VoxelCenterDetector(BasePose):
    """Detect human center by 3D CNN on voxels.

    Please refer to the
    `paper <https://arxiv.org/abs/2004.06239>` for details.

    Args:
        image_size (list): input size of the 2D model.
        heatmap_size (list): output size of the 2D model.
        space_size (list): Size of the 3D space.
        cube_size (list): Size of the input volume to the 3D CNN.
        space_center (list): Coordinate of the center of the 3D space.
        center_net (ConfigDict): Dictionary to construct the center net.
        center_head (ConfigDict): Dictionary to construct the center head.
        train_cfg (ConfigDict): Config for training. Default: None.
        test_cfg (ConfigDict): Config for testing. Default: None.
    """

    def __init__(
        self,
        image_size,
        heatmap_size,
        space_size,
        cube_size,
        space_center,
        center_net,
        center_head,
        train_cfg=None,
        test_cfg=None,
    ):
        super(VoxelCenterDetector, self).__init__()
        self.project_layer = ProjectLayer(image_size, heatmap_size)
        self.center_net = builder.build_backbone(center_net)
        self.center_head = builder.build_head(center_head)

        self.space_size = space_size
        self.cube_size = cube_size
        self.space_center = space_center

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def assign2gt(self, center_candidates, gt_centers, gt_num_persons):
        """Assign gt id to each valid human center candidate.

        Each candidate is matched to its nearest ground-truth center. The
        index of that center is written to channel 3 of the candidate, or
        ``-1`` when the distance exceeds ``train_cfg['dist_threshold']`` or
        the sample has no ground-truth person.

        Args:
            center_candidates (torch.Tensor): Candidates of shape
                ``(N, P, D)``; channels 0-2 are read as the 3D center and
                channel 3 is overwritten in place.
            gt_centers (torch.Tensor): Ground-truth centers, indexed as
                ``[sample, person]``.
            gt_num_persons (torch.Tensor): Number of ground-truth persons
                in each sample.

        Returns:
            torch.Tensor: ``center_candidates``, modified in place.
        """
        det_centers = center_candidates[..., :3]
        batch_size = center_candidates.shape[0]
        cand_num = center_candidates.shape[1]
        # Allocate next to the candidates so that the masked write below
        # never mixes devices.
        cand2gt = center_candidates.new_zeros(batch_size, cand_num)

        for i in range(batch_size):
            cand = det_centers[i].view(cand_num, 1, -1)
            gt = gt_centers[None, i, :gt_num_persons[i]]

            if gt.shape[1] == 0:
                # Nothing to match against; min() over an empty dimension
                # would raise.
                cand2gt[i] = -1.0
                continue

            dist = torch.sqrt(torch.sum((cand - gt)**2, dim=-1))
            min_dist, min_gt = torch.min(dist, dim=-1)

            cand2gt[i] = min_gt
            cand2gt[i][min_dist > self.train_cfg['dist_threshold']] = -1.0

        center_candidates[:, :, 3] = cand2gt

        return center_candidates

    def forward(self,
                img,
                img_metas,
                return_loss=True,
                feature_maps=None,
                targets_3d=None):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            heatmaps width: W
            heatmaps height: H

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            return_loss: Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.
            targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
                Ground-truth 3D heatmap of human centers.
            feature_maps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps.

        Returns:
            dict: if 'return_loss' is true, then return losses.
                Otherwise, return the human center candidates.
        """
        if return_loss:
            return self.forward_train(img, img_metas, feature_maps, targets_3d)
        else:
            return self.forward_test(img, img_metas, feature_maps)

    def forward_train(self,
                      img,
                      img_metas,
                      feature_maps=None,
                      targets_3d=None,
                      return_preds=False):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            heatmaps width: W
            heatmaps height: H

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            targets_3d (torch.Tensor[NxcubeLxcubeWxcubeH]):
                Ground-truth 3D heatmap of human centers.
            feature_maps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps.
            return_preds (bool): Whether to return prediction results

        Returns:
            dict: if 'return_preds' is true, then return the human centers
                and the losses as a tuple. Otherwise, return losses only
        """
        initial_cubes, _ = self.project_layer(feature_maps, img_metas,
                                              self.space_size,
                                              [self.space_center],
                                              self.cube_size)
        center_heatmaps_3d = self.center_net(initial_cubes)
        center_heatmaps_3d = center_heatmaps_3d.squeeze(1)
        center_candidates = self.center_head(center_heatmaps_3d)

        device = center_candidates.device

        gt_centers = torch.stack([
            torch.tensor(img_meta['roots_3d'], device=device)
            for img_meta in img_metas
        ])
        gt_num_persons = torch.stack([
            torch.tensor(img_meta['num_persons'], device=device)
            for img_meta in img_metas
        ])
        center_candidates = self.assign2gt(center_candidates, gt_centers,
                                           gt_num_persons)

        losses = dict()
        losses.update(
            self.center_head.get_loss(center_heatmaps_3d, targets_3d))

        if return_preds:
            return center_candidates, losses
        else:
            return losses

    def forward_test(self, img, img_metas, feature_maps=None):
        """
        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C
            img_width: imgW
            img_height: imgH
            heatmaps width: W
            heatmaps height: H

        Args:
            img (list(torch.Tensor[NxCximgHximgW])):
                Multi-camera input images to the 2D model.
            img_metas (list(dict)):
                Information about image, 3D groundtruth and camera parameters.
            feature_maps (list(torch.Tensor[NxKxHxW])):
                Multi-camera feature_maps.

        Returns:
            human centers
        """
        initial_cubes, _ = self.project_layer(feature_maps, img_metas,
                                              self.space_size,
                                              [self.space_center],
                                              self.cube_size)
        center_heatmaps_3d = self.center_net(initial_cubes)
        center_heatmaps_3d = center_heatmaps_3d.squeeze(1)
        center_candidates = self.center_head(center_heatmaps_3d)
        # Channel 3 becomes 0 where channel 4 exceeds the threshold and
        # -1 elsewhere.
        center_candidates[..., 3] = \
            (center_candidates[..., 4] >
             self.test_cfg['center_threshold']).float() - 1.0

        return center_candidates

    def show_result(self, **kwargs):
        """Visualize the results."""
        raise NotImplementedError

    def forward_dummy(self, feature_maps):
        """Used for computing network FLOPs."""
        batch_size, num_channels, _, _ = feature_maps[0].shape
        initial_cubes = feature_maps[0].new_zeros(batch_size, num_channels,
                                                  *self.cube_size)
        _ = self.center_net(initial_cubes)