# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from .. import builder
from ..builder import POSENETS
@POSENETS.register_module()
class MultiTask(nn.Module):
    """Multi-task detectors.

    A single backbone feeds several heads; each head may optionally be
    routed through a neck. Heads without a mapped neck receive the raw
    backbone features via a trailing ``nn.Identity`` neck (index ``-1``).

    Args:
        backbone (dict): Backbone modules to extract feature.
        heads (list[dict]): heads to output predictions.
        necks (list[dict] | None): necks to process feature.
        head2neck (dict{int:int}): head index to neck index.
        pretrained (str): Path to the pretrained models.
    """

    def __init__(self,
                 backbone,
                 heads,
                 necks=None,
                 head2neck=None,
                 pretrained=None):
        super().__init__()
        self.backbone = builder.build_backbone(backbone)

        if head2neck is None:
            # No mapping means no necks may be supplied at all; every head
            # then falls through to the identity neck below.
            assert necks is None
            head2neck = {}

        assert heads is not None
        # -1 indexes the trailing nn.Identity, so unmapped heads consume
        # the backbone features unchanged.
        self.head2neck = {
            i: head2neck.get(i, -1)
            for i in range(len(heads))
        }

        self.necks = nn.ModuleList([])
        if necks is not None:
            for neck_cfg in necks:
                self.necks.append(builder.build_neck(neck_cfg))
        # Sentinel identity neck, addressed by index -1.
        self.necks.append(nn.Identity())

        self.heads = nn.ModuleList([])
        for head_cfg in heads:
            assert head_cfg is not None
            self.heads.append(builder.build_head(head_cfg))

        self.init_weights(pretrained=pretrained)

    @property
    def with_necks(self):
        """bool: Whether the model has necks."""
        return hasattr(self, 'necks')

    def init_weights(self, pretrained=None):
        """Weight initialization for model.

        Args:
            pretrained (str | None): Path to pretrained backbone weights,
                forwarded to the backbone's own ``init_weights``.
        """
        self.backbone.init_weights(pretrained)
        if self.with_necks:
            for neck in self.necks:
                # nn.Identity (and possibly other necks) has no init_weights.
                if hasattr(neck, 'init_weights'):
                    neck.init_weights()
        for head in self.heads:
            if hasattr(head, 'init_weights'):
                head.init_weights()

    def _forward_heads(self, features):
        """Run every head on its neck-processed features.

        Args:
            features: Backbone output.

        Returns:
            list: One output per head, in head order.
        """
        return [
            head(self.necks[self.head2neck[head_id]](features))
            for head_id, head in enumerate(self.heads)
        ]

    def forward(self,
                img,
                target=None,
                target_weight=None,
                img_metas=None,
                return_loss=True,
                **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=True`, img and img_meta are single-nested (i.e.
        Tensor and List[dict]), and when `return_loss=False`, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_img_channel: C (Default: 3)
            - img height: imgH
            - img width: imgW
            - heatmaps height: H
            - heatmaps width: W

        Args:
            img (torch.Tensor[N,C,imgH,imgW]): Input images.
            target (list[torch.Tensor]): Targets.
            target_weight (List[torch.Tensor]): Weights.
            img_metas (list(dict)): Information about data augmentation
                By default this includes:

                - "image_file: path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            return_loss (bool): Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.

        Returns:
            dict|tuple: if `return loss` is true, then return losses. \
                Otherwise, return predicted poses, boxes, image paths \
                and heatmaps.
        """
        if return_loss:
            return self.forward_train(img, target, target_weight, img_metas,
                                      **kwargs)
        return self.forward_test(img, img_metas, **kwargs)

    def forward_train(self, img, target, target_weight, img_metas, **kwargs):
        """Defines the computation performed at every call when training.

        Returns:
            dict: Losses (and accuracies, when a head provides
            ``get_accuracy``) from all heads, merged into one dict.
        """
        features = self.backbone(img)
        outputs = self._forward_heads(features)

        losses = dict()
        for head, output, gt, gt_weight in zip(self.heads, outputs, target,
                                               target_weight):
            loss = head.get_loss(output, gt, gt_weight)
            # Loss keys must be unique across heads, or update() would
            # silently overwrite earlier entries.
            assert len(set(losses.keys()).intersection(set(loss.keys()))) == 0
            losses.update(loss)

            if hasattr(head, 'get_accuracy'):
                acc = head.get_accuracy(output, gt, gt_weight)
                assert len(set(losses.keys()).intersection(set(
                    acc.keys()))) == 0
                losses.update(acc)

        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Defines the computation performed at every call when testing.

        Returns:
            dict: Decoded results merged from all heads.
        """
        assert img.size(0) == len(img_metas)
        batch_size, _, img_height, img_width = img.shape
        if batch_size > 1:
            # bbox_id is required to disambiguate multiple boxes per image.
            assert 'bbox_id' in img_metas[0]

        results = {}
        features = self.backbone(img)

        outputs = []
        for head_id, head in enumerate(self.heads):
            neck_id = self.head2neck[head_id]
            if hasattr(head, 'inference_model'):
                head_output = head.inference_model(
                    self.necks[neck_id](features), flip_pairs=None)
            else:
                head_output = head(
                    self.necks[neck_id](features)).detach().cpu().numpy()
            outputs.append(head_output)

        for head, output in zip(self.heads, outputs):
            result = head.decode(
                img_metas, output, img_size=[img_width, img_height])
            results.update(result)

        return results

    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        See ``tools/get_flops.py``.

        Args:
            img (torch.Tensor): Input image.

        Returns:
            list[Tensor]: Outputs.
        """
        features = self.backbone(img)
        return self._forward_heads(features)