|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Config template to train RetinaNet."""
|
|
|
from official.modeling.hyperparams import params_dict |
|
from official.vision.detection.configs import base_config |
|
|
|
|
|
|
|
# RetinaNet configuration: start from the shared base config and layer the
# RetinaNet-specific sections on top of it via ParamsDict.override.
RETINANET_CFG = params_dict.ParamsDict(base_config.BASE_CFG)

RETINANET_CFG.override(
    dict(
        type='retinanet',
        architecture=dict(
            parser='retinanet_parser',
        ),
        # Values stored under the 'retinanet_parser' key.
        retinanet_parser=dict(
            output_size=[640, 640],
            num_channels=3,
            match_threshold=0.5,
            unmatched_threshold=0.5,
            aug_rand_hflip=True,
            aug_scale_min=1.0,
            aug_scale_max=1.0,
            use_autoaugment=False,
            autoaugment_policy_name='v0',
            skip_crowd_during_training=True,
            max_num_instances=100,
        ),
        # Values stored under the 'retinanet_head' key.
        retinanet_head=dict(
            anchors_per_location=9,
            num_convs=4,
            num_filters=256,
            use_separable_conv=False,
        ),
        # Values stored under the 'retinanet_loss' key.
        retinanet_loss=dict(
            focal_loss_alpha=0.25,
            focal_loss_gamma=1.5,
            huber_loss_delta=0.1,
            box_loss_weight=50,
        ),
        enable_summary=True,
    ),
    # NOTE(review): is_strict=False presumably lets the override introduce
    # keys that BASE_CFG does not already define; confirm against
    # ParamsDict.override.
    is_strict=False,
)
|
|
|
# No restrictions are declared for the RetinaNet config.
# NOTE(review): presumably consumed alongside RETINANET_CFG by the config
# validation code; confirm with the caller.
RETINANET_RESTRICTIONS = []
|
|
|
|
|
|