import torch
import torch.nn as nn

from UltraFlow import layers


# pairwise logistic (log) ranking loss on predicted scores
class pair_wise_ranking_loss(nn.Module):
    def __init__(self, config):
        super(pair_wise_ranking_loss, self).__init__()
        self.config = config
        # Zeroes |target difference| <= 0.2, so close pairs count as ties.
        self.threshold_filter = nn.Threshold(0.2, 0)
        # Maps each interaction embedding to a scalar predicted score.
        self.score_predict = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, 1)
    def ranking_loss(self, z_A, z_B, relation):
        """
        Logistic ranking loss over a batch of pairs:
        z_A: predicted scores for items A
        z_B: predicted scores for items B
        relation: -1, 0, or 1 per pair; pairs with relation == 0 (ties) are skipped
        """
        score_diff = z_A - z_B
        log_loss = torch.mean(torch.log(1 + torch.exp(-relation[relation != 0] * score_diff[relation != 0])))
        return log_loss
    @torch.no_grad()
    def get_rank_relation(self, y_A, y_B):
        # Sign of the target difference; pairs with |y_A - y_B| <= 0.2 are
        # zeroed by threshold_filter and treated as ties.
        diff = y_A - y_B
        diff[self.threshold_filter(diff.abs()) == 0] = 0
        return diff.sign()
    def forward(self, output_embedding, target):
        # The batch concatenates both members of each pair: the first half
        # are items A, the second half items B.
        batch_repeat_num = len(output_embedding)
        batch_size = batch_repeat_num // 2
        score_predict = self.score_predict(output_embedding)
        x_A, y_A, x_B, y_B = score_predict[:batch_size], target[:batch_size], \
                             score_predict[batch_size:], target[batch_size:]
        relation = self.get_rank_relation(y_A, y_B)
        ranking_loss = self.ranking_loss(x_A, x_B, relation)
        relation_pred = self.get_rank_relation(x_A, x_B)
        return ranking_loss, relation.squeeze(), relation_pred.squeeze()
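

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, self-contained check of the logistic ranking loss defined above;
# the helper name and all tensor values are hypothetical.
def _demo_logistic_ranking_loss():
    z_A = torch.tensor([[1.5], [0.3], [2.0]])  # predicted scores, items A
    z_B = torch.tensor([[0.5], [0.4], [2.1]])  # predicted scores, items B
    relation = torch.tensor([[1], [0], [-1]])  # 0 marks a tie, skipped below
    mask = relation != 0
    diff = z_A - z_B
    # Same formula as pair_wise_ranking_loss.ranking_loss.
    return torch.mean(torch.log(1 + torch.exp(-relation[mask] * diff[mask])))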


# pairwise ranking cast as two-class classification (cross-entropy over {<=, >})
class pair_wise_ranking_loss_v2(nn.Module):
    def __init__(self, config):
        super(pair_wise_ranking_loss_v2, self).__init__()
        self.config = config
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        self.loss_fn = nn.CrossEntropyLoss()
        # Classifies the concatenated pair embedding into {A <= B, A > B}.
        self.relation_mlp = layers.FC(config.model.inter_out_dim * 4, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
        self.m = nn.Softmax(dim=1)

@torch.no_grad()
def get_rank_relation(self, y_A, y_B):
# y_A: [batch, 1]
# target_relation: 0: <=, 1: >
target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device)
target_relation[(y_A - y_B) > 0.0] = 1
return target_relation.squeeze()

    def forward(self, output_embedding, target, assay_des):
        # First half of the batch are items A, second half items B.
        batch_repeat_num = len(output_embedding)
        batch_size = batch_repeat_num // 2
        x_A, y_A, x_B, y_B = output_embedding[:batch_size], target[:batch_size], \
                             output_embedding[batch_size:], target[batch_size:]
        relation = self.get_rank_relation(y_A, y_B)
        if self.pretrain_use_assay_description:
            # Fuse each embedding with its assay description before comparison.
            assay_A, assay_B = assay_des[:batch_size], assay_des[batch_size:]
            agg_A = x_A + assay_A
            agg_B = x_B + assay_B
            relation_pred = self.relation_mlp(torch.cat([agg_A, agg_B], dim=1))
        else:
            relation_pred = self.relation_mlp(torch.cat([x_A, x_B], dim=1))
        ranking_loss = self.loss_fn(relation_pred, relation)
        _, y_pred = self.m(relation_pred).max(dim=1)
        return ranking_loss, relation.squeeze(), y_pred
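

# --- Hedged sketch of the label construction used by get_rank_relation ---
# Shows how a pair (y_A, y_B) maps to class 1 when y_A > y_B and class 0
# otherwise; the helper name and values are illustrative.
def _demo_rank_relation():
    y_A = torch.tensor([[5.2], [3.1]])
    y_B = torch.tensor([[4.0], [3.5]])
    labels = torch.zeros(y_A.size(), dtype=torch.long)
    labels[(y_A - y_B) > 0.0] = 1
    return labels.squeeze()  # tensor([1, 0])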


# same two-class ranking formulation, but kept unreduced: one loss per pair
class pairwise_BCE_loss(nn.Module):
    def __init__(self, config):
        super(pairwise_BCE_loss, self).__init__()
        self.config = config
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        # reduction='none' keeps a per-pair loss vector.
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.relation_mlp = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1) * 2, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
        else:
            self.relation_mlp = layers.FC(config.model.inter_out_dim * 4, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
        self.m = nn.Softmax(dim=1)

@torch.no_grad()
def get_rank_relation(self, y_A, y_B):
# y_A: [batch, 1]
# target_relation: 0: <=, 1: >
target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device)
target_relation[(y_A - y_B) > 0.0] = 1
return target_relation.squeeze()

    def forward(self, output_embedding, target, assay_des):
        # First half of the batch are items A, second half items B.
        batch_repeat_num = len(output_embedding)
        batch_size = batch_repeat_num // 2
        x_A, y_A, x_B, y_B = output_embedding[:batch_size], target[:batch_size], \
                             output_embedding[batch_size:], target[batch_size:]
        relation = self.get_rank_relation(y_A, y_B)
        if self.pretrain_use_assay_description:
            # Fuse each embedding with its assay description before comparison.
            assay_A, assay_B = assay_des[:batch_size], assay_des[batch_size:]
            agg_A = x_A + assay_A
            agg_B = x_B + assay_B
            relation_pred = self.relation_mlp(torch.cat([agg_A, agg_B], dim=1))
        else:
            relation_pred = self.relation_mlp(torch.cat([x_A, x_B], dim=1))
        # Unreduced: one cross-entropy value per pair.
        ranking_loss = self.loss_fn(relation_pred, relation)
        _, y_pred = self.m(relation_pred).max(dim=1)
        return ranking_loss, relation.squeeze(), y_pred
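

# --- Hedged usage sketch for pairwise_BCE_loss (illustrative only) ---
# Assumes UltraFlow.layers.FC is importable as above; the config values,
# shapes, and helper name below are hypothetical, chosen only to satisfy the
# fields this file reads (config.train.*, config.model.*). Outputs are
# meaningless here since the MLP weights are untrained.
def _demo_pairwise_bce_loss():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        train=SimpleNamespace(pretrain_use_assay_description=False),
        model=SimpleNamespace(inter_out_dim=64, dropout=0.1,
                              readout='sum', attn_merge='concat'),
    )
    loss_fn = pairwise_BCE_loss(cfg)
    emb = torch.randn(8, cfg.model.inter_out_dim * 2)  # 4 pairs: A half, then B half
    target = torch.randn(8, 1)
    per_pair_loss, relation, y_pred = loss_fn(emb, target, assay_des=None)
    return per_pair_loss.mean(), relation, y_pred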