Spaces:
Running
Running
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: Penalty loss on the squared R, G, B channel values of an image tensor.
import torch | |
def channel_saturation_penalty_loss(x: torch.Tensor) -> torch.Tensor:
    """Return the mean of the squared channel values of a 3-channel tensor.

    The three planes along dimension 1 are squared element-wise and summed,
    the sum is averaged over all remaining elements, and the result is
    divided by 3 (the number of channels).

    Args:
        x: A 4-D tensor whose dimension 1 has size 3. The indexing below
            (``x[:, c, :, :]``) requires exactly four dimensions; the names
            suggest a (batch, RGB, height, width) layout, but only the rank
            and the channel count are enforced here.

    Returns:
        A 0-dim tensor holding the penalty value.

    Raises:
        AssertionError: If dimension 1 of ``x`` does not have size 3.
    """
    assert x.shape[1] == 3, f"expected 3 channels in dim 1, got {x.shape[1]}"

    r_channel = x[:, 0, :, :]
    g_channel = x[:, 1, :, :]
    b_channel = x[:, 2, :, :]

    # Per-position sum of squares across the three channels.
    channel_accumulate = (
        torch.pow(r_channel, 2)
        + torch.pow(g_channel, 2)
        + torch.pow(b_channel, 2)
    )

    # Divide by the channel count so the result is a per-channel average.
    return channel_accumulate.mean() / 3