# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from .. import builder
from ..builder import POSENETS
class MultiTask(nn.Module):
    """Multi-task detectors.

    A single shared backbone feeds several (neck, head) branches. Heads
    without a dedicated neck are routed through a trailing ``nn.Identity``.

    Args:
        backbone (dict): Backbone modules to extract feature.
        heads (list[dict]): heads to output predictions.
        necks (list[dict] | None): necks to process feature.
        head2neck (dict{int: int}): head index to neck index.
        pretrained (str): Path to the pretrained models.
    """

    def __init__(self,
                 backbone,
                 heads,
                 necks=None,
                 head2neck=None,
                 pretrained=None):
        super().__init__()

        self.backbone = builder.build_backbone(backbone)

        if head2neck is None:
            assert necks is None
            head2neck = {}

        # Map every head index to a neck index; unmapped heads get -1,
        # which resolves to the trailing nn.Identity() appended below.
        self.head2neck = {}
        for i in range(len(heads)):
            self.head2neck[i] = head2neck.get(i, -1)

        self.necks = nn.ModuleList([])
        if necks is not None:
            for neck in necks:
                self.necks.append(builder.build_neck(neck))
        # Identity neck for heads that consume backbone features directly.
        self.necks.append(nn.Identity())

        self.heads = nn.ModuleList([])
        assert heads is not None
        for head in heads:
            assert head is not None
            self.heads.append(builder.build_head(head))

        self.init_weights(pretrained=pretrained)

    @property
    def with_necks(self):
        """bool: Whether the detector has necks."""
        # Must be a property: init_weights tests ``self.with_necks`` as an
        # attribute; a plain method object would always be truthy.
        return hasattr(self, 'necks')

    def init_weights(self, pretrained=None):
        """Weight initialization for model.

        Args:
            pretrained (str | None): Path to pretrained backbone weights,
                forwarded to the backbone's own ``init_weights``.
        """
        self.backbone.init_weights(pretrained)
        if self.with_necks:
            for neck in self.necks:
                if hasattr(neck, 'init_weights'):
                    neck.init_weights()
        for head in self.heads:
            if hasattr(head, 'init_weights'):
                head.init_weights()

    def forward(self,
                img,
                target=None,
                target_weight=None,
                img_metas=None,
                return_loss=True,
                **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note this setting will change the expected inputs. When
        `return_loss=True`, img and img_meta are single-nested (i.e.
        Tensor and List[dict]), and when `return_loss=False`, img and
        img_meta should be double nested (i.e. List[Tensor],
        List[List[dict]]), with the outer list indicating test time
        augmentations.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_img_channel: C (Default: 3)
            - img height: imgH
            - img weight: imgW
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            img (torch.Tensor[N,C,imgH,imgW]): Input images.
            target (list[torch.Tensor]): Targets.
            target_weight (List[torch.Tensor]): Weights.
            img_metas (list(dict)): Information about data augmentation
                By default this includes:

                - "image_file: path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            return_loss (bool): Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.

        Returns:
            dict|tuple: if `return loss` is true, then return losses. \
                Otherwise, return predicted poses, boxes, image paths \
                and heatmaps.
        """
        if return_loss:
            return self.forward_train(img, target, target_weight, img_metas,
                                      **kwargs)
        return self.forward_test(img, img_metas, **kwargs)

    def forward_train(self, img, target, target_weight, img_metas, **kwargs):
        """Defines the computation performed at every call when training."""
        features = self.backbone(img)
        outputs = []

        for head_id, head in enumerate(self.heads):
            neck_id = self.head2neck[head_id]
            outputs.append(head(self.necks[neck_id](features)))

        # Merge per-head losses/accuracies; keys must be globally unique
        # across heads, otherwise one head would silently overwrite another.
        losses = dict()
        for head, output, gt, gt_weight in zip(self.heads, outputs, target,
                                               target_weight):
            loss = head.get_loss(output, gt, gt_weight)
            assert len(set(losses.keys()).intersection(set(loss.keys()))) == 0
            losses.update(loss)
            if hasattr(head, 'get_accuracy'):
                acc = head.get_accuracy(output, gt, gt_weight)
                assert len(set(losses.keys()).intersection(set(
                    acc.keys()))) == 0
                losses.update(acc)

        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Defines the computation performed at every call when testing."""
        assert img.size(0) == len(img_metas)
        batch_size, _, img_height, img_width = img.shape
        if batch_size > 1:
            # bbox_id is required to disambiguate multiple boxes per image.
            assert 'bbox_id' in img_metas[0]

        results = {}
        features = self.backbone(img)
        outputs = []

        for head_id, head in enumerate(self.heads):
            neck_id = self.head2neck[head_id]
            if hasattr(head, 'inference_model'):
                head_output = head.inference_model(
                    self.necks[neck_id](features), flip_pairs=None)
            else:
                head_output = head(
                    self.necks[neck_id](features)).detach().cpu().numpy()
            outputs.append(head_output)

        for head, output in zip(self.heads, outputs):
            result = head.decode(
                img_metas, output, img_size=[img_width, img_height])
            results.update(result)
        return results

    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        See ``tools/get_flops.py``.

        Args:
            img (torch.Tensor): Input image.

        Returns:
            list[Tensor]: Outputs.
        """
        features = self.backbone(img)
        outputs = []
        for head_id, head in enumerate(self.heads):
            neck_id = self.head2neck[head_id]
            outputs.append(head(self.necks[neck_id](features)))
        return outputs