# -*- coding: utf-8 -*- # Copyright (c) XiMing Xing. All rights reserved. # Author: XiMing Xing # Description: import torch def channel_saturation_penalty_loss(x: torch.Tensor): assert x.shape[1] == 3 r_channel = x[:, 0, :, :] g_channel = x[:, 1, :, :] b_channel = x[:, 2, :, :] channel_accumulate = torch.pow(r_channel, 2) + torch.pow(g_channel, 2) + torch.pow(b_channel, 2) return channel_accumulate.mean() / 3