Spaces:
Build error
Build error
""" | |
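@date: 2021/7/19
@description: Build the loss criteria enabled in the training config and combine their outputs into a single weighted loss.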
""" | |
import torch | |
import loss | |
from utils.misc import tensor2np | |
def build_criterion(config, logger):
    """Instantiate every enabled loss listed under ``config.TRAIN.CRITERION``.

    An entry is skipped when its ``WEIGHT`` is ``None`` or converts to a float
    equal to zero. For every other entry, the class named by ``LOSS`` is looked
    up on the ``loss`` module, constructed without arguments and moved to
    ``config.TRAIN.DEVICE``. When ``config.AMP_OPT_LEVEL`` is not ``"O0"`` and
    the device contains ``'cuda'``, the loss is additionally cast to
    ``torch.float16``.

    :param config: configuration object exposing ``TRAIN.DEVICE``,
        ``TRAIN.CRITERION`` (indexable, with ``keys()``) and ``AMP_OPT_LEVEL``.
    :param logger: accepted for interface compatibility; not used here.
    :return: dict keyed by each entry's ``NAME``; every value is a dict with
        the keys ``'loss'``, ``'weight'``, ``'sub_weights'`` and ``'need_all'``.
    """
    train_cfg = config.TRAIN
    device = train_cfg.DEVICE
    specs = train_cfg.CRITERION

    criterion = {}
    for spec_key in specs.keys():
        spec = specs[spec_key]

        # Disabled entries: no weight configured, or a weight of zero.
        if spec.WEIGHT is None or float(spec.WEIGHT) == 0:
            continue

        loss_module = getattr(loss, spec.LOSS)().to(device)
        # The membership test assumes ``device`` is a string such as 'cuda:0'.
        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
            loss_module = loss_module.type(torch.float16)

        criterion[spec.NAME] = {
            'loss': loss_module,
            'weight': float(spec.WEIGHT),
            'sub_weights': spec.WEIGHTS,
            'need_all': spec.NEED_ALL,
        }
    return criterion
def calc_criterion(criterion, gt, dt, epoch_loss_d):
    """Evaluate every criterion and accumulate the weighted total loss.

    For an entry with ``'need_all'`` set, the loss callable receives the whole
    ``gt`` and ``dt`` objects; its indexable result is combined using the
    non-zero ``'sub_weights'`` (if every sub-weight is zero, the raw result is
    used as is). Otherwise the callable receives ``gt[k]`` and ``dt[k]``.

    The un-weighted value of each criterion is converted with ``tensor2np``,
    stored in the returned postfix dict and appended to ``epoch_loss_d[k]``.
    The value is then scaled by the entry's ``'weight'`` and added to the
    total, which is recorded the same way under the key ``'loss'``.

    Side effects: ``epoch_loss_d`` is updated in place, and ``gt['ratio']`` is
    overwritten with a repeated copy when its last dimension differs from
    that of ``dt['ratio']``.

    :param criterion: dict as produced by ``build_criterion``.
    :param gt: ground-truth container, indexable by criterion name.
    :param dt: prediction container, indexable by criterion name.
    :param epoch_loss_d: dict mapping a name to the list of values seen so far.
    :return: ``(total_loss, postfix_d, epoch_loss_d)``; ``total_loss`` stays
        ``None`` when ``criterion`` is empty.
    """
    postfix_d = {}

    def record(key, value):
        # Keep the converted value for this step and for the running history.
        postfix_d[key] = tensor2np(value)
        epoch_loss_d.setdefault(key, []).append(postfix_d[key])

    total = None
    for name, entry in criterion.items():
        if entry['need_all']:
            term = entry['loss'](gt, dt)
            weighted_parts = [
                term[idx] * sub_weight
                for idx, sub_weight in enumerate(entry['sub_weights'])
                if sub_weight != 0
            ]
            if weighted_parts:
                # Left-to-right accumulation of the weighted components.
                term = weighted_parts[0]
                for part in weighted_parts[1:]:
                    term = term + part
        else:
            assert name in gt.keys(), "ground label is None:" + name
            assert name in dt.keys(), "detection key is None:" + name
            if name == 'ratio' and gt[name].shape[-1] != dt[name].shape[-1]:
                # Widen the ground truth to match the prediction's last dim.
                gt[name] = gt[name].repeat(1, dt[name].shape[-1])
            term = entry['loss'](gt[name], dt[name])

        # Record the un-weighted value before scaling it into the total.
        record(name, term)

        term = term * entry['weight']
        total = term if total is None else total + term

    record('loss', total)
    return total, postfix_d, epoch_loss_d