Spaces:
Runtime error
Runtime error
File size: 29,581 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 |
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_conv_layer
from mmengine.model import BaseModule, ModuleDict
from mmengine.structures import InstanceData, PixelData
from torch import Tensor
from mmpose.models.utils.tta import flip_heatmaps
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
OptSampleList, Predictions)
from ..base_head import BaseHead
def smooth_heatmaps(heatmaps: Tensor, blur_kernel_size: int) -> Tensor:
    """Smooth the heatmaps by blurring and averaging.

    The heatmaps are blurred with a stride-1 average pooling, and the blurred
    result is then averaged with the original heatmaps.

    Args:
        heatmaps (Tensor): The heatmaps to smooth.
        blur_kernel_size (int): The kernel size for blurring the heatmaps.
            Must be a positive odd integer, otherwise the padded pooling
            output would not have the same spatial size as the input.

    Returns:
        Tensor: The smoothed heatmaps, with the same shape as ``heatmaps``.

    Raises:
        ValueError: If ``blur_kernel_size`` is not a positive odd integer.
    """
    # With padding (k - 1) // 2 and stride 1, only odd kernel sizes keep the
    # spatial size unchanged. Fail early with an explicit message instead of
    # an obscure shape-mismatch error in the addition below.
    if blur_kernel_size < 1 or blur_kernel_size % 2 == 0:
        raise ValueError('`blur_kernel_size` must be a positive odd integer, '
                         f'but got {blur_kernel_size}')

    padding = (blur_kernel_size - 1) // 2
    blurred_heatmaps = torch.nn.functional.avg_pool2d(
        heatmaps, blur_kernel_size, 1, padding)
    return (heatmaps + blurred_heatmaps) / 2.0
class TruncSigmoid(nn.Sigmoid):
    """Sigmoid activation whose output is clipped to a fixed range.

    Args:
        min (float, optional): Lower bound applied to the sigmoid output.
            Defaults to 0.0
        max (float, optional): Upper bound applied to the sigmoid output.
            Defaults to 1.0
    """

    def __init__(self, min: float = 0.0, max: float = 1.0):
        super().__init__()
        self.min, self.max = min, max

    def forward(self, input: Tensor) -> Tensor:
        """Apply the sigmoid and clip the result to ``[min, max]``."""
        return torch.clamp(torch.sigmoid(input), min=self.min, max=self.max)
class IIAModule(BaseModule):
    """Instance Information Abstraction module introduced in `CID`. This module
    extracts the feature representation vectors for each instance.

    Args:
        in_channels (int): Number of channels in the input feature tensor
        out_channels (int): Number of channels of the output heatmaps
        clamp_delta (float, optional): A small value that prevents the sigmoid
            activation from becoming saturated. Defaults to 1e-4.
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        clamp_delta: float = 1e-4,
        init_cfg: OptConfigType = None,
    ):
        super().__init__(init_cfg=init_cfg)

        self.keypoint_root_conv = build_conv_layer(
            dict(
                type='Conv2d',
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1))
        # keep the activation inside [clamp_delta, 1 - clamp_delta]
        self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta)

    def forward(self, feats: Tensor):
        """Compute the heatmaps from the input feature map.

        Args:
            feats (Tensor): Input feature tensor.

        Returns:
            Tensor: Heatmaps after the truncated sigmoid activation.
        """
        heatmaps = self.keypoint_root_conv(feats)
        heatmaps = self.sigmoid(heatmaps)
        return heatmaps

    def _sample_feats(self, feats: Tensor, indices: Tensor) -> Tensor:
        """Extract feature vectors at the specified indices from the input
        feature map.

        Args:
            feats (Tensor): Input feature map.
            indices (Tensor): Indices of the feature vectors to extract, of
                dtype ``torch.long``. Each row is either ``(b, w, h)``, or
                ``(w, h)``, in which case the location is sampled from every
                sample in the batch.

        Returns:
            Tensor: Extracted feature vectors.

        Raises:
            ValueError: If ``indices`` has neither 2 nor 3 channels.
        """
        assert indices.dtype == torch.long
        if indices.shape[1] == 3:
            b, w, h = [ind.squeeze(-1) for ind in indices.split(1, -1)]
            instance_feats = feats[b, :, h, w]
        elif indices.shape[1] == 2:
            w, h = [ind.squeeze(-1) for ind in indices.split(1, -1)]
            # sample the same locations from every sample of the batch and
            # flatten the (batch, instance) dimensions into one
            instance_feats = feats[:, :, h, w]
            instance_feats = instance_feats.permute(0, 2, 1)
            instance_feats = instance_feats.reshape(-1,
                                                    instance_feats.shape[-1])
        else:
            raise ValueError(f'`indices` should have 2 or 3 channels, '
                             f'but got {indices.shape[1]}')
        return instance_feats

    def _hierarchical_pool(self, heatmaps: Tensor) -> Tensor:
        """Conduct max pooling on the input heatmaps with different kernel size
        according to the input size.

        Args:
            heatmaps (Tensor): Input heatmaps.

        Returns:
            Tensor: Result of hierarchical pooling.
        """
        # larger maps use a larger pooling window; stride 1 with matching
        # padding is used in every case
        map_size = (heatmaps.shape[-1] + heatmaps.shape[-2]) / 2.0
        if map_size > 300:
            maxm = torch.nn.functional.max_pool2d(heatmaps, 7, 1, 3)
        elif map_size > 200:
            maxm = torch.nn.functional.max_pool2d(heatmaps, 5, 1, 2)
        else:
            maxm = torch.nn.functional.max_pool2d(heatmaps, 3, 1, 1)
        return maxm

    def forward_train(self, feats: Tensor, instance_coords: Tensor,
                      instance_imgids: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward pass during training.

        Args:
            feats (Tensor): Input feature tensor.
            instance_coords (Tensor): Coordinates of the instance roots.
            instance_imgids (Tensor): Sample indices of each instances
                in the batch.

        Returns:
            Tuple[Tensor, Tensor]: Extracted feature vectors and heatmaps
            for the instances.
        """
        heatmaps = self.forward(feats)
        # prepend the sample index so each row becomes (b, w, h)
        indices = torch.cat((instance_imgids[:, None], instance_coords), dim=1)
        instance_feats = self._sample_feats(feats, indices)

        return instance_feats, heatmaps

    def forward_test(
        self, feats: Tensor, test_cfg: Dict
    ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """Forward pass during testing.

        Args:
            feats (Tensor): Input feature tensor.
            test_cfg (Dict): Testing configuration, including:

                - blur_kernel_size (int, optional): Kernel size for blurring
                    the heatmaps. Defaults to 3.
                - max_instances (int, optional): Maximum number of instances
                    to extract. Defaults to 30.
                - score_threshold (float, optional): Minimum score for
                    extracting an instance. Defaults to 0.01.
                - flip_test (bool, optional): Whether to compute the average
                    of the heatmaps across the batch dimension.
                    Defaults to False.

        Returns:
            A tuple of Tensor including extracted feature vectors,
            coordinates, and scores of the instances. Any of these can be
            empty Tensor if no instances are extracted.
        """
        blur_kernel_size = test_cfg.get('blur_kernel_size', 3)
        max_instances = test_cfg.get('max_instances', 30)
        score_threshold = test_cfg.get('score_threshold', 0.01)
        W = feats.shape[-1]

        # compute heatmaps; only the last channel is used to locate instances
        heatmaps = self.forward(feats).narrow(1, -1, 1)
        if test_cfg.get('flip_test', False):
            heatmaps = heatmaps.mean(dim=0, keepdim=True)
        smoothed_heatmaps = smooth_heatmaps(heatmaps, blur_kernel_size)

        # decode heatmaps: keep only the local maxima, zero everything else
        maximums = self._hierarchical_pool(smoothed_heatmaps)
        maximums = torch.eq(maximums, smoothed_heatmaps).float()
        maximums = (smoothed_heatmaps * maximums).reshape(-1)
        # `topk` raises when k exceeds the number of elements, so never ask
        # for more candidates than the heatmap contains
        num_candidates = min(max_instances, maximums.numel())
        scores, pos_ind = maximums.topk(num_candidates, dim=0)
        select_ind = (scores > score_threshold).nonzero().squeeze(1)
        scores, pos_ind = scores[select_ind], pos_ind[select_ind]

        # sample feature vectors from feature map
        instance_coords = torch.stack((pos_ind % W, pos_ind // W), dim=1)
        instance_feats = self._sample_feats(feats, instance_coords)

        return instance_feats, instance_coords, scores
class ChannelAttention(nn.Module):
    """Channel-wise attention module introduced in `CID`.

    Args:
        in_channels (int): The number of channels of the input instance
            vectors.
        out_channels (int): The number of channels of the transformed instance
            vectors.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.atn = nn.Linear(in_channels, out_channels)

    def forward(self, global_feats: Tensor, instance_feats: Tensor) -> Tensor:
        """Scale the global features channel-wise by the linearly transformed
        instance vectors."""
        channel_weights = self.atn(instance_feats)
        # append two singleton axes so the weights broadcast over the
        # trailing dimensions of `global_feats`
        channel_weights = channel_weights[:, :, None, None]
        return global_feats * channel_weights
class SpatialAttention(nn.Module):
    """Spatial-wise attention module introduced in `CID`.

    Args:
        in_channels (int): The number of channels of the input instance
            vectors.
        out_channels (int): The number of channels of the transformed instance
            vectors.
    """

    def __init__(self, in_channels, out_channels):
        super(SpatialAttention, self).__init__()
        self.atn = nn.Linear(in_channels, out_channels)
        # NOTE: not referenced anywhere else in this class
        self.feat_stride = 4
        # 3 input channels: 1 summed feature map + 2 relative-coordinate maps
        self.conv = nn.Conv2d(3, 1, 5, 1, 2)

    def _get_pixel_coords(self,
                          heatmap_size: Tuple,
                          device: Union[str, torch.device] = 'cpu'):
        """Get pixel coordinates for each element in the heatmap.

        Args:
            heatmap_size (tuple): Size of the heatmap in (W, H) format.
            device (str | torch.device): Device to put the resulting tensor
                on.

        Returns:
            Tensor of shape (num_pixels, 2) containing the (x, y) coordinates
            of the pixel centers (i.e. offset by 0.5), enumerated in
            row-major order.
        """
        w, h = heatmap_size
        xs = torch.arange(w, dtype=torch.float32, device=device)
        ys = torch.arange(h, dtype=torch.float32, device=device)
        # Build the (h, w) grids by broadcasting rather than with
        # `torch.meshgrid`, whose default indexing emits a deprecation
        # warning on recent PyTorch releases. The result is identical.
        grid_x = xs.reshape(1, w).expand(h, w)
        grid_y = ys.reshape(h, 1).expand(h, w)
        pixel_coords = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)
        return pixel_coords + 0.5

    def forward(self, global_feats: Tensor, instance_feats: Tensor,
                instance_coords: Tensor) -> Tensor:
        """Perform spatial attention.

        Args:
            global_feats (Tensor): Tensor containing the global features.
            instance_feats (Tensor): Tensor containing the instance feature
                vectors.
            instance_coords (Tensor): Tensor containing the root coordinates
                of the instances.

        Returns:
            Tensor containing the modulated global features.
        """
        B, C, H, W = global_feats.size()

        # channel-weighted sum of the global features -> (B, 1, H, W)
        instance_feats = self.atn(instance_feats).reshape(B, C, 1, 1)
        feats = global_feats * instance_feats
        fsum = torch.sum(feats, dim=1, keepdim=True)

        # offset of every pixel from each instance root, scaled by a fixed
        # constant -> (B, 2, H, W)
        pixel_coords = self._get_pixel_coords((W, H), feats.device)
        relative_coords = instance_coords.reshape(
            -1, 1, 2) - pixel_coords.reshape(1, -1, 2)
        relative_coords = relative_coords.permute(0, 2, 1) / 32.0
        relative_coords = relative_coords.reshape(B, 2, H, W)

        input_feats = torch.cat((fsum, relative_coords), dim=1)
        mask = self.conv(input_feats).sigmoid()
        return global_feats * mask
class GFDModule(BaseModule):
    """Global Feature Decoupling module introduced in `CID`. This module
    extracts the decoupled heatmaps for each instance.

    Args:
        in_channels (int): Number of channels in the input feature map
        out_channels (int): Number of channels of the output heatmaps
            for each instance
        gfd_channels (int): Number of channels in the transformed feature map
        clamp_delta (float, optional): A small value that prevents the sigmoid
            activation from becoming saturated. Defaults to 1e-4.
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        gfd_channels: int,
        clamp_delta: float = 1e-4,
        init_cfg: OptConfigType = None,
    ):
        super().__init__(init_cfg=init_cfg)

        def conv1x1(num_in: int, num_out: int):
            # every convolution in this module is a 1x1 `Conv2d`
            return build_conv_layer(
                dict(
                    type='Conv2d',
                    in_channels=num_in,
                    out_channels=num_out,
                    kernel_size=1))

        # reduce the input feature map to `gfd_channels` channels
        self.conv_down = conv1x1(in_channels, gfd_channels)

        # instance-conditioned attention branches
        self.channel_attention = ChannelAttention(in_channels, gfd_channels)
        self.spatial_attention = SpatialAttention(in_channels, gfd_channels)

        # merge the two concatenated branches, then predict the heatmaps
        self.fuse_attention = conv1x1(gfd_channels * 2, gfd_channels)
        self.heatmap_conv = conv1x1(gfd_channels, out_channels)
        self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta)

    def forward(
        self,
        feats: Tensor,
        instance_feats: Tensor,
        instance_coords: Tensor,
        instance_imgids: Tensor,
    ) -> Tensor:
        """Extract decoupled heatmaps for each instance.

        Args:
            feats (Tensor): Input feature maps.
            instance_feats (Tensor): Tensor containing the instance feature
                vectors.
            instance_coords (Tensor): Tensor containing the root coordinates
                of the instances.
            instance_imgids (Tensor): Sample indices of each instances
                in the batch.

        Returns:
            A tensor containing decoupled heatmaps.
        """
        # pick, for every instance, the reduced feature map of its sample
        per_instance_feats = self.conv_down(feats)[instance_imgids]

        channel_branch = self.channel_attention(per_instance_feats,
                                                instance_feats)
        spatial_branch = self.spatial_attention(per_instance_feats,
                                                instance_feats,
                                                instance_coords)

        fused = self.fuse_attention(
            torch.cat((channel_branch, spatial_branch), dim=1))
        logits = self.heatmap_conv(torch.nn.functional.relu(fused))

        return self.sigmoid(logits)
@MODELS.register_module()
class CIDHead(BaseHead):
    """Contextual Instance Decoupling head introduced in `Contextual Instance
    Decoupling for Robust Multi-Person Pose Estimation (CID)`_ by Wang et al
    (2022). The head is composed of an Instance Information Abstraction (IIA)
    module and a Global Feature Decoupling (GFD) module.

    Args:
        in_channels (int | Sequence[int]): Number of channels in the input
            feature map
        gfd_channels (int): Number of filters in GFD module
        num_keypoints (int): Number of keypoints
        prior_prob (float): Prior probability used to compute the initial
            bias ``-log((1 - prior_prob) / prior_prob)`` of the heatmap
            output convolutions. Defaults to 0.01
        coupled_heatmap_loss (Config): Config of the loss for coupled heatmaps.
            Defaults to use :class:`FocalHeatmapLoss`
        decoupled_heatmap_loss (Config): Config of the loss for decoupled
            heatmaps. Defaults to use :class:`FocalHeatmapLoss`
        contrastive_loss (Config): Config of the contrastive loss for
            representation vectors of instances. Defaults to use
            :class:`InfoNCELoss`
        decoder (Config, optional): The decoder config that controls decoding
            keypoint coordinates from the network output. Defaults to ``None``
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings

    .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_
    Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_
    CVPR_2022_paper.html
    """

    _version = 2

    def __init__(self,
                 in_channels: Union[int, Sequence[int]],
                 gfd_channels: int,
                 num_keypoints: int,
                 prior_prob: float = 0.01,
                 coupled_heatmap_loss: OptConfigType = dict(
                     type='FocalHeatmapLoss'),
                 decoupled_heatmap_loss: OptConfigType = dict(
                     type='FocalHeatmapLoss'),
                 contrastive_loss: OptConfigType = dict(type='InfoNCELoss'),
                 decoder: OptConfigType = None,
                 init_cfg: OptConfigType = None):

        if init_cfg is None:
            init_cfg = self.default_init_cfg
        elif isinstance(init_cfg, dict):
            # a single init config is wrapped so that it can be extended
            # with the sub-module specific entries below
            init_cfg = [init_cfg]

        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.num_keypoints = num_keypoints
        if decoder is not None:
            self.decoder = KEYPOINT_CODECS.build(decoder)
        else:
            self.decoder = None

        # build sub-modules
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        sub_init_cfg = list(init_cfg)
        self.iia_module = IIAModule(
            in_channels,
            num_keypoints + 1,
            init_cfg=sub_init_cfg + [
                dict(
                    type='Normal',
                    layer=['Conv2d', 'Linear'],
                    std=0.001,
                    override=dict(
                        name='keypoint_root_conv',
                        type='Normal',
                        std=0.001,
                        bias=bias_value))
            ])
        self.gfd_module = GFDModule(
            in_channels,
            num_keypoints,
            gfd_channels,
            init_cfg=sub_init_cfg + [
                dict(
                    type='Normal',
                    layer=['Conv2d', 'Linear'],
                    std=0.001,
                    override=dict(
                        name='heatmap_conv',
                        type='Normal',
                        std=0.001,
                        bias=bias_value))
            ])

        # build losses
        self.loss_module = ModuleDict(
            dict(
                heatmap_coupled=MODELS.build(coupled_heatmap_loss),
                heatmap_decoupled=MODELS.build(decoupled_heatmap_loss),
                contrastive=MODELS.build(contrastive_loss),
            ))

        # Register the hook to automatically convert old version state dicts
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    @property
    def default_init_cfg(self):
        """list[dict]: The initialization config used when ``init_cfg`` is
        not specified."""
        init_cfg = [
            dict(type='Normal', layer=['Conv2d', 'Linear'], std=0.001),
            dict(type='Constant', layer='BatchNorm2d', val=1)
        ]
        return init_cfg

    def forward(self, feats: Tuple[Tensor]) -> Tensor:
        """Forward the network. The input is multi scale feature maps and the
        output is the heatmap.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tensor: output heatmap.
        """
        # only the last feature map is used
        feats = feats[-1]
        # instances are extracted with the default testing configuration
        instance_info = self.iia_module.forward_test(feats, {})
        instance_feats, instance_coords, instance_scores = instance_info
        # every instance is assigned to the first sample of the batch
        instance_imgids = torch.zeros(
            instance_coords.size(0), dtype=torch.long, device=feats.device)
        instance_heatmaps = self.gfd_module(feats, instance_feats,
                                            instance_coords, instance_imgids)

        return instance_heatmaps

    def predict(self,
                feats: Features,
                batch_data_samples: OptSampleList,
                test_cfg: ConfigType = {}) -> Predictions:
        """Predict results from features.

        Args:
            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
                features (or multiple multi-stage features in TTA)
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            test_cfg (dict): The runtime config for testing process. Defaults
                to {}

        Returns:
            Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
            ``test_cfg['output_heatmaps']==True``, return both pose and heatmap
            prediction; otherwise only return the pose prediction.

            The pose prediction is a list of ``InstanceData``, each contains
            the following fields:

                - keypoints (np.ndarray): predicted keypoint coordinates in
                    shape (num_instances, K, D) where K is the keypoint number
                    and D is the keypoint dimension
                - keypoint_scores (np.ndarray): predicted keypoint scores in
                    shape (num_instances, K)

            The heatmap prediction is a list of ``PixelData``, each contains
            the following fields:

                - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
        """
        metainfo = batch_data_samples[0].metainfo
        flip_test = test_cfg.get('flip_test', False)

        if flip_test:
            assert isinstance(feats, list) and len(feats) == 2

            # stack the original features and the flipped-back features of
            # the flipped input along the batch dimension
            feats_flipped = flip_heatmaps(feats[1][-1], shift_heatmap=False)
            feats = torch.cat((feats[0][-1], feats_flipped))
        else:
            feats = feats[-1]

        instance_info = self.iia_module.forward_test(feats, test_cfg)
        instance_feats, instance_coords, instance_scores = instance_info
        if len(instance_coords) > 0:
            instance_imgids = torch.zeros(
                instance_coords.size(0), dtype=torch.long, device=feats.device)
            if flip_test:
                # the same instances are decoded once more from the second
                # (flipped) sample of the batch
                instance_coords = torch.cat((instance_coords, instance_coords))
                instance_imgids = torch.cat(
                    (instance_imgids, instance_imgids + 1))
            instance_heatmaps = self.gfd_module(feats, instance_feats,
                                                instance_coords,
                                                instance_imgids)
            if flip_test:
                # reorder the keypoint channels of the flipped half, then
                # average both halves
                flip_indices = metainfo['flip_indices']
                instance_heatmaps, instance_heatmaps_flip = torch.chunk(
                    instance_heatmaps, 2, dim=0)
                instance_heatmaps_flip = \
                    instance_heatmaps_flip[:, flip_indices, :, :]
                instance_heatmaps = (instance_heatmaps +
                                     instance_heatmaps_flip) / 2.0
            instance_heatmaps = smooth_heatmaps(
                instance_heatmaps, test_cfg.get('blur_kernel_size', 3))

            preds = self.decode((instance_heatmaps, instance_scores[:, None]))
            preds = InstanceData.cat(preds)
            # shift the decoded keypoints by half of the ratio between the
            # input size and the heatmap size along each axis
            preds.keypoints[..., 0] += metainfo['input_size'][
                0] / instance_heatmaps.shape[-1] / 2.0
            preds.keypoints[..., 1] += metainfo['input_size'][
                1] / instance_heatmaps.shape[-2] / 2.0
            preds = [preds]

        else:
            # no instance found: return empty predictions
            preds = [
                InstanceData(
                    keypoints=np.empty((0, self.num_keypoints, 2)),
                    keypoint_scores=np.empty((0, self.num_keypoints)))
            ]
            instance_heatmaps = torch.empty(0, self.num_keypoints,
                                            *feats.shape[-2:])

        if test_cfg.get('output_heatmaps', False):
            pred_fields = [
                PixelData(
                    heatmaps=instance_heatmaps.reshape(
                        -1, *instance_heatmaps.shape[-2:]))
            ]
            return preds, pred_fields
        else:
            return preds

    def loss(self,
             feats: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: ConfigType = {}) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            feats (Tuple[Tensor]): The multi-stage features
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            train_cfg (dict): The runtime config for training process.
                Defaults to {}. If ``max_train_instances`` is given and
                positive, at most that many randomly selected instances are
                used for the decoupled heatmap loss.

        Returns:
            dict: A dictionary of losses.
        """

        # load targets
        gt_heatmaps, gt_instance_coords, keypoint_weights = [], [], []
        heatmap_mask = []
        instance_imgids, gt_instance_heatmaps = [], []
        for i, d in enumerate(batch_data_samples):
            gt_heatmaps.append(d.gt_fields.heatmaps)
            gt_instance_coords.append(d.gt_instance_labels.instance_coords)
            keypoint_weights.append(d.gt_instance_labels.keypoint_weights)
            # every instance of this sample is tagged with the sample index
            instance_imgids.append(
                torch.ones(
                    len(d.gt_instance_labels.instance_coords),
                    dtype=torch.long) * i)

            # group the instance heatmaps by instance
            instance_heatmaps = d.gt_fields.instance_heatmaps.reshape(
                -1, self.num_keypoints,
                *d.gt_fields.instance_heatmaps.shape[1:])
            gt_instance_heatmaps.append(instance_heatmaps)

            if 'heatmap_mask' in d.gt_fields:
                heatmap_mask.append(d.gt_fields.heatmap_mask)

        gt_heatmaps = torch.stack(gt_heatmaps)
        heatmap_mask = torch.stack(heatmap_mask) if heatmap_mask else None

        gt_instance_coords = torch.cat(gt_instance_coords, dim=0)
        gt_instance_heatmaps = torch.cat(gt_instance_heatmaps, dim=0)
        keypoint_weights = torch.cat(keypoint_weights, dim=0)
        instance_imgids = torch.cat(instance_imgids).to(gt_heatmaps.device)

        # feed-forward
        feats = feats[-1]
        pred_instance_feats, pred_heatmaps = self.iia_module.forward_train(
            feats, gt_instance_coords, instance_imgids)

        # compute contrastive loss, accumulated per sample and normalized
        # by the total number of instances
        contrastive_loss = 0
        for i in range(len(batch_data_samples)):
            pred_instance_feat = pred_instance_feats[instance_imgids == i]
            contrastive_loss += self.loss_module['contrastive'](
                pred_instance_feat)
        contrastive_loss = contrastive_loss / max(1, len(instance_imgids))

        # limit the number of instances
        max_train_instances = train_cfg.get('max_train_instances', -1)
        if (max_train_instances > 0
                and len(instance_imgids) > max_train_instances):
            selected_indices = torch.randperm(
                len(instance_imgids),
                device=gt_heatmaps.device,
                dtype=torch.long)[:max_train_instances]
            gt_instance_coords = gt_instance_coords[selected_indices]
            keypoint_weights = keypoint_weights[selected_indices]
            gt_instance_heatmaps = gt_instance_heatmaps[selected_indices]
            instance_imgids = instance_imgids[selected_indices]
            pred_instance_feats = pred_instance_feats[selected_indices]

        # calculate the decoupled heatmaps for each instance
        pred_instance_heatmaps = self.gfd_module(feats, pred_instance_feats,
                                                 gt_instance_coords,
                                                 instance_imgids)

        # calculate losses
        losses = {
            'loss/heatmap_coupled':
            self.loss_module['heatmap_coupled'](pred_heatmaps, gt_heatmaps,
                                                None, heatmap_mask)
        }
        # instance-level losses are only reported when instances exist
        if len(instance_imgids) > 0:
            losses.update({
                'loss/heatmap_decoupled':
                self.loss_module['heatmap_decoupled'](pred_instance_heatmaps,
                                                      gt_instance_heatmaps,
                                                      keypoint_weights),
                'loss/contrastive':
                contrastive_loss
            })

        return losses

    def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
                                  **kwargs):
        """A hook function to convert old-version state dict of
        :class:`CIDHead` (before MMPose v1.0.0) to a compatible format
        of :class:`CIDHead`.

        Only keys that belong to this head (i.e. start with ``prefix``) and
        whose first component is one of the old sub-module names are renamed.
        Keys that are already in the current format are left untouched, so
        the conversion is safe to apply when the version metadata is missing.

        The hook will be automatically registered during initialization.
        """
        version = local_meta.get('version', None)
        if version and version >= self._version:
            return

        # old sub-module name -> current name, both relative to this head
        renamed_modules = {
            'keypoint_center_conv': 'iia_module.keypoint_root_conv',
            'conv_down': 'gfd_module.conv_down',
            'c_attn': 'gfd_module.channel_attention',
            's_attn': 'gfd_module.spatial_attention',
            'fuse_attn': 'gfd_module.fuse_attention',
            'heatmap_conv': 'gfd_module.heatmap_conv',
        }

        # convert old-version state dict; iterate over a snapshot of the keys
        # because the dict is modified inside the loop
        for k in list(state_dict.keys()):
            if not k.startswith(prefix):
                # the key belongs to another module
                continue
            sub_key = k[len(prefix):]
            old_name = sub_key.split('.', 1)[0]
            new_name = renamed_modules.get(old_name)
            if new_name is None:
                continue
            new_key = prefix + new_name + sub_key[len(old_name):]
            state_dict[new_key] = state_dict.pop(k)
|