Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
import contextlib | |
from unittest import mock | |
import torch | |
from detectron2.modeling import poolers | |
from detectron2.modeling.proposal_generator import rpn | |
from detectron2.modeling.roi_heads import keypoint_head, mask_head | |
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers | |
from .c10 import ( | |
Caffe2Compatible, | |
Caffe2FastRCNNOutputsInference, | |
Caffe2KeypointRCNNInference, | |
Caffe2MaskRCNNInference, | |
Caffe2ROIPooler, | |
Caffe2RPN, | |
caffe2_fast_rcnn_outputs_inference, | |
caffe2_keypoint_rcnn_inference, | |
caffe2_mask_rcnn_inference, | |
) | |
class GenericMixin:
    """
    Marker base class with no behavior of its own.

    `Caffe2CompatibleConverter.create_from` checks whether its replacement
    class derives from this marker to decide between mixing the replacement
    into the module's existing class and swapping the class outright.
    """
class Caffe2CompatibleConverter:
    """
    Updater implementing the `create_from` interface expected by `patch`.

    It converts a module in place by re-assigning the module's `__class__`
    to (or to a class derived from) `replaceCls`.
    """

    def __init__(self, replaceCls):
        # Class that converted modules will become (or be mixed with).
        self.replaceCls = replaceCls

    def create_from(self, module):
        """
        Re-class `module` in place and return it.

        Args:
            module (torch.nn.Module): the module to convert.

        Returns:
            The same `module` object, now an instance of the new class.
        """
        assert isinstance(module, torch.nn.Module)

        replacement = self.replaceCls
        if issubclass(replacement, GenericMixin):
            # The replacement is only a mixin: build a combined class on the
            # fly that inherits from both it and the module's current class.
            current_cls = module.__class__
            combined_name = "{}MixedWith{}".format(replacement.__name__, current_cls.__name__)
            replacement = type(combined_name, (replacement, current_cls), {})
        # Otherwise the replacement is a complete class and is swapped in as-is.
        module.__class__ = replacement

        # Initialize the Caffe2Compatible state on the converted module.
        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False
        return module
def patch(model, target, updater, *args, **kwargs):
    """
    Recursively (post-order) rewrite every module that is an instance of
    `target` (including subclasses) using `updater.create_from`.

    Children are processed before their parent, and each child slot in
    `model._modules` is overwritten with whatever the recursive call returns.

    Args:
        model: module exposing `named_children()` and `_modules`.
        target: type (or tuple of types) passed to `isinstance`.
        updater: object providing `create_from(module, *args, **kwargs)`.
        *args, **kwargs: forwarded unchanged to `updater.create_from`.

    Returns:
        `updater.create_from(model, ...)` when `model` matches `target`,
        otherwise `model` itself.
    """
    for child_name, child in model.named_children():
        model._modules[child_name] = patch(child, target, updater, *args, **kwargs)

    if not isinstance(model, target):
        return model
    return updater.create_from(model, *args, **kwargs)
def patch_generalized_rcnn(model):
    """
    Convert the RPN and ROIPooler modules inside `model` to their Caffe2
    counterparts and return the patched model.

    Each (target type, replacement class) pair is applied in order through
    `patch`, using a `Caffe2CompatibleConverter` as the updater.
    """
    conversions = (
        (rpn.RPN, Caffe2RPN),
        (poolers.ROIPooler, Caffe2ROIPooler),
    )
    for target_type, caffe2_cls in conversions:
        model = patch(model, target_type, Caffe2CompatibleConverter(caffe2_cls))
    return model
@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    """
    Context manager that temporarily replaces `box_predictor_type.inference`
    with a mock delegating to `Caffe2FastRCNNOutputsInference(tensor_mode)`.

    Args:
        tensor_mode: forwarded to `Caffe2FastRCNNOutputsInference`.
        check (bool): if True, assert on normal exit that the mocked
            `inference` was called at least once.
        box_predictor_type: class whose `inference` attribute is patched.
    """
    # The decorator is required: this is a generator function, and callers
    # use its return value in `with` / `ExitStack.enter_context`.
    with mock.patch.object(
        box_predictor_type,
        "inference",
        autospec=True,
        side_effect=Caffe2FastRCNNOutputsInference(tensor_mode),
    ) as mocked_func:
        yield
    # Only reached when the managed body finished without raising.
    if check:
        assert mocked_func.call_count > 0
@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    """
    Context manager that temporarily replaces `<patched_module>.mask_rcnn_inference`
    with a mock delegating to `Caffe2MaskRCNNInference()`.

    Args:
        tensor_mode: accepted for signature parity with the other mock
            helpers; not used in this function's body.
        patched_module (str): dotted path of the module whose
            `mask_rcnn_inference` attribute is patched.
        check (bool): if True, assert on normal exit that the mocked
            function was called at least once.
    """
    # The decorator is required: this is a generator function, and callers
    # use its return value in `with` / `ExitStack.enter_context`.
    with mock.patch(
        "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference()
    ) as mocked_func:
        yield
    # Only reached when the managed body finished without raising.
    if check:
        assert mocked_func.call_count > 0
@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    """
    Context manager that temporarily replaces
    `<patched_module>.keypoint_rcnn_inference` with a mock delegating to
    `Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)`.

    Args:
        tensor_mode: accepted for signature parity with the other mock
            helpers; not used in this function's body.
        patched_module (str): dotted path of the module whose
            `keypoint_rcnn_inference` attribute is patched.
        use_heatmap_max_keypoint: forwarded to `Caffe2KeypointRCNNInference`.
        check (bool): if True, assert on normal exit that the mocked
            function was called at least once.
    """
    # The decorator is required: this is a generator function, and callers
    # use its return value in `with` / `ExitStack.enter_context`.
    with mock.patch(
        "{}.keypoint_rcnn_inference".format(patched_module),
        side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint),
    ) as mocked_func:
        yield
    # Only reached when the managed body finished without raising.
    if check:
        assert mocked_func.call_count > 0
class ROIHeadsPatcher:
    """
    Helper that swaps the inference functions used by an ROI heads object
    for their Caffe2 counterparts, either temporarily (`mock_roi_heads`) or
    until explicitly undone (`patch_roi_heads` / `unpatch_roi_heads`).
    """

    def __init__(self, heads, use_heatmap_max_keypoint):
        # ROI heads object; `box_predictor`, `keypoint_on` and `mask_on`
        # are read from it by the methods below.
        self.heads = heads
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
        # Originals saved by `patch_roi_heads`, restored by `unpatch_roi_heads`.
        self.previous_patched = {}

    @contextlib.contextmanager
    def mock_roi_heads(self, tensor_mode=True):
        """
        Patching several inference functions inside ROIHeads and its subclasses

        Usable as a context manager; all patches are active inside the
        `with` body and are undone when it exits.

        Args:
            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
                format or not. Default to True.
        """
        # NOTE: this requires the `keypoint_rcnn_inference` and `mask_rcnn_inference`
        # are called inside the same file as BaseXxxHead due to using mock.patch.
        kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
        mask_head_mod = mask_head.BaseMaskRCNNHead.__module__

        mock_ctx_managers = [
            mock_fastrcnn_outputs_inference(
                tensor_mode=tensor_mode,
                check=True,
                box_predictor_type=type(self.heads.box_predictor),
            )
        ]
        if getattr(self.heads, "keypoint_on", False):
            mock_ctx_managers.append(
                mock_keypoint_rcnn_inference(
                    tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
                )
            )
        if getattr(self.heads, "mask_on", False):
            mock_ctx_managers.append(mock_mask_rcnn_inference(tensor_mode, mask_head_mod))

        # ExitStack enters a variable number of context managers and unwinds
        # them in reverse order on exit.
        with contextlib.ExitStack() as stack:  # python 3.3+
            for mgr in mock_ctx_managers:
                stack.enter_context(mgr)
            yield

    def patch_roi_heads(self, tensor_mode=True):
        """
        Replace the inference functions in place, remembering the originals
        in `self.previous_patched` so `unpatch_roi_heads` can restore them.

        The box predictor's `inference` is always replaced; the module-level
        `keypoint_head.keypoint_rcnn_inference` and
        `mask_head.mask_rcnn_inference` are replaced only when the heads have
        `keypoint_on` / `mask_on` set.
        """
        # NOTE(review): `tensor_mode` is not used below; the box predictor
        # wrapper always passes True as its first argument. Confirm whether
        # that is intentional before relying on `tensor_mode=False` here.
        self.previous_patched["box_predictor"] = self.heads.box_predictor.inference
        self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference
        self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference

        def patched_fastrcnn_outputs_inference(predictions, proposal):
            return caffe2_fast_rcnn_outputs_inference(
                True, self.heads.box_predictor, predictions, proposal
            )

        self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference

        if getattr(self.heads, "keypoint_on", False):

            def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
                return caffe2_keypoint_rcnn_inference(
                    self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
                )

            keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference

        if getattr(self.heads, "mask_on", False):

            def patched_mask_rcnn_inference(pred_mask_logits, pred_instances):
                return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)

            mask_head.mask_rcnn_inference = patched_mask_rcnn_inference

    def unpatch_roi_heads(self):
        """
        Restore the three functions saved by `patch_roi_heads`.

        Raises:
            KeyError: if `patch_roi_heads` has not stored the originals yet.
        """
        self.heads.box_predictor.inference = self.previous_patched["box_predictor"]
        keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"]
        mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"]