DeepFake-Videos-Detection / training /loss /region_independent_loss.py
anyantudre's picture
moved from training repo to inference
caa56d6
raw
history blame
2.8 kB
import torch
import torch.nn.functional as F
from detectors.multi_attention_detector import AttentionPooling
from .abstract_loss_func import AbstractLossClass
from metrics.registry import LOSSFUNC
@LOSSFUNC.register_module(module_name="region_independent_loss")
class RegionIndependentLoss(AbstractLossClass):
    """Loss over attention-pooled region features with running feature centers.

    The module keeps an ``(M, N)`` buffer of feature centers that is updated on
    every call as a moving average of the pooled features. The returned loss is
    the sum of two terms:

    * an intra-class term that penalises pooled features lying farther from
      their center than a label-dependent margin, and
    * an inter-class term that penalises pairs of centers lying closer to each
      other than ``inter_margin``.

    The moving-average rate ``alpha`` is multiplied by ``alpha_decay`` every
    ``decay_batch`` calls to ``forward``.
    """

    def __init__(self, M, N, alpha, alpha_decay, decay_batch, inter_margin, intra_margin):
        """Create the loss and its persistent state.

        Args:
            M: Number of rows of the center buffer (matches the channel count
                of the attention maps passed to ``forward``).
            N: Number of columns of the center buffer (matches the channel
                count of the feature maps passed to ``forward``).
            alpha: Initial moving-average rate for the center update.
            alpha_decay: Factor applied to ``alpha`` every ``decay_batch`` calls.
            decay_batch: Number of ``forward`` calls between two decays.
            inter_margin: Minimum distance between two centers before the
                inter-class term becomes zero.
            intra_margin: Sequence of per-label margins, indexed by label value.
        """
        super().__init__()
        # Initial placement is unchanged (CUDA when available); ``forward``
        # re-aligns the buffers with the device of its inputs.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.register_buffer("feature_centers", torch.zeros(M, N, device=device))
        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.decay_batch = decay_batch
        self.batch_cnt = 0
        self.inter_margin = inter_margin
        # ``torch.as_tensor`` always interprets its argument as data, unlike the
        # legacy ``torch.Tensor(...)`` constructor which treats an int as a size.
        self.register_buffer(
            "intra_margin",
            torch.as_tensor(intra_margin, dtype=torch.get_default_dtype(), device=device),
        )
        self.atp = AttentionPooling()

    def forward(self, feature_maps_d, attention_maps, labels):
        """Update the feature centers and return the combined loss.

        Args:
            feature_maps_d: 4-D tensor of size ``(B, N, H, W)``.
            attention_maps: 4-D tensor of size ``(B, M, AH, AW)``; resized
                bilinearly to ``(H, W)`` when the spatial sizes differ.
            labels: One integer label per sample, used to index
                ``intra_margin``.

        Returns:
            Scalar tensor: inter-class loss plus intra-class loss.
        """
        B, N, H, W = feature_maps_d.size()
        B, M, AH, AW = attention_maps.size()
        if AH != H or AW != W:
            attention_maps = F.interpolate(attention_maps, (H, W), mode='bilinear', align_corners=True)
        # Presumably (B, M, N): it is broadcast against the (M, N) centers
        # below -- confirm against AttentionPooling.
        feature_matrix = self.atp(feature_maps_d, attention_maps)
        device = feature_matrix.device

        # Moving-average update of the centers. The centers are persistent
        # state rather than part of the graph, so the whole update runs
        # without gradient tracking.
        with torch.no_grad():
            old_centers = self.feature_centers.to(device)
            new_centers = old_centers + self.alpha * torch.mean(feature_matrix - old_centers, dim=0)
            self.feature_centers = new_centers

        # Intra-class loss: distance to the center beyond the per-label margin.
        label_idx = labels.reshape(-1).to(device=device, dtype=torch.long)
        margins = self.intra_margin.to(device)[label_idx].unsqueeze(1)
        distances = torch.norm(feature_matrix - new_centers, dim=-1)
        intra_class_loss = torch.mean(F.relu(distances - margins))

        # Inter-class loss: every unordered pair (i < j) of centers that is
        # closer than ``inter_margin``.
        # NOTE(review): the centers carry no gradient, so this term adds to the
        # loss value but not to the backward pass -- confirm this is intended.
        rows, cols = torch.triu_indices(M, M, offset=1, device=device)
        pair_dists = torch.norm(new_centers[rows] - new_centers[cols], dim=-1)
        inter_class_loss = F.relu(self.inter_margin - pair_dists).sum() / M / self.alpha

        # Batch counting stands in for epochs: alpha cannot be adjusted per
        # epoch from here, so it is decayed every ``decay_batch`` calls instead.
        self.batch_cnt += 1
        if self.batch_cnt % self.decay_batch == 0:
            self.alpha *= self.alpha_decay
        return inter_class_loss + intra_class_loss