hjc-owo
init repo
966ae59
raw
history blame contribute delete
439 Bytes
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import torch
def channel_saturation_penalty_loss(x: torch.Tensor) -> torch.Tensor:
    """Return the mean squared channel magnitude of a 3-channel tensor.

    For every position the squares of the three channel values are summed;
    the result is averaged over all positions and divided by 3 (the channel
    count), so the loss equals the mean of ``x ** 2`` over every element.

    Args:
        x: Tensor whose dimension 1 has size 3, i.e. shape ``(N, 3, *)``.
            The channels are treated as R, G, B in that order, although the
            computation is symmetric in them.

    Returns:
        A zero-dimensional tensor holding the loss value.

    Raises:
        ValueError: If ``x`` has fewer than two dimensions or its dimension 1
            is not of size 3.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if x.dim() < 2 or x.shape[1] != 3:
        raise ValueError(
            "expected a tensor of shape (N, 3, ...), "
            f"got shape {tuple(x.shape)}"
        )

    # Index only dimension 1 so any number of trailing dimensions is accepted.
    red = x[:, 0]
    green = x[:, 1]
    blue = x[:, 2]

    squared_sum = torch.pow(red, 2) + torch.pow(green, 2) + torch.pow(blue, 2)
    return squared_sum.mean() / 3