Spaces:
Running
Running
File size: 439 Bytes
966ae59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import torch
def channel_saturation_penalty_loss(x: torch.Tensor) -> torch.Tensor:
    """Return the mean of the squared values of a 3-channel tensor.

    For each position, the squares of channels 0, 1 and 2 (named r, g, b
    here) are summed; the result is averaged over every remaining position
    and divided by 3.  Because exactly three channels are summed, this equals
    ``x.pow(2).mean()``.

    Args:
        x: Tensor whose dimension 1 holds the three channels.  Any rank >= 2
            is accepted, e.g. ``(N, 3)`` or ``(N, 3, H, W)``.  Presumably an
            RGB image batch in the calling code -- NOTE(review): confirm
            channel order/value range against callers.

    Returns:
        A 0-dim tensor with the scalar penalty.

    Raises:
        AssertionError: if ``x.shape[1] != 3``.  Raised explicitly (not via
            ``assert``) so the check is not stripped under ``python -O``,
            while keeping the exception type callers may already catch.
        IndexError: if ``x`` has fewer than 2 dimensions.
    """
    if x.shape[1] != 3:
        raise AssertionError(
            f"expected 3 channels in dim 1, got shape {tuple(x.shape)}"
        )

    # `x[:, c]` selects channel c for any rank >= 2; the original
    # `x[:, c, :, :]` needlessly required rank >= 4.
    r_channel = x[:, 0]
    g_channel = x[:, 1]
    b_channel = x[:, 2]

    # Same elementwise expression and reduction as before, so results for
    # rank >= 4 inputs are numerically unchanged.
    channel_accumulate = (
        torch.pow(r_channel, 2) + torch.pow(g_channel, 2) + torch.pow(b_channel, 2)
    )
    return channel_accumulate.mean() / 3
|