Spaces:
Running
Running
import numpy as np | |
import torch | |
import torch.nn as nn | |
from UltraFlow import layers | |
# margin ranking loss | |
class pair_wise_ranking_loss(nn.Module): | |
def __init__(self, config): | |
super(pair_wise_ranking_loss, self).__init__() | |
self.config = config | |
self.threshold_filter = nn.Threshold(0.2, 0) | |
self.score_predict = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, 1) | |
def ranking_loss(self, z_A, z_B, relation): | |
""" | |
loss for a given set of pixels: | |
z_A: predicted absolute depth for pixels A | |
z_B: predicted absolute depth for pixels B | |
relation: -1, 0, 1 | |
""" | |
pred_depth = z_A - z_B | |
log_loss = torch.mean(torch.log(1 + torch.exp(-relation[relation != 0] * pred_depth[relation != 0]))) | |
return log_loss | |
def get_rank_relation(self, y_A, y_B): | |
pred_depth = y_A - y_B | |
pred_depth[self.threshold_filter(pred_depth.abs()) == 0] = 0 | |
return pred_depth.sign() | |
def forward(self, output_embedding, target): | |
batch_repeat_num = len(output_embedding) | |
batch_size = batch_repeat_num // 2 | |
score_predict = self.score_predict(output_embedding) | |
x_A, y_A, x_B, y_B = score_predict[:batch_size], target[:batch_size], score_predict[batch_size:], target[batch_size:] | |
relation = self.get_rank_relation(y_A, y_B) | |
ranking_loss = self.ranking_loss(x_A, x_B, relation) | |
relation_pred = self.get_rank_relation(x_A, x_B) | |
return ranking_loss, relation.squeeze(), relation_pred.squeeze() | |
# binary cross entropy loss | |
class pair_wise_ranking_loss_v2(nn.Module): | |
def __init__(self, config): | |
super(pair_wise_ranking_loss_v2, self).__init__() | |
self.config = config | |
self.pretrain_use_assay_description = config.train.pretrain_use_assay_description | |
self.loss_fn = nn.CrossEntropyLoss() | |
self.relation_mlp = layers.FC(config.model.inter_out_dim * 4, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2) | |
self.m = nn.Softmax(dim=1) | |
def get_rank_relation(self, y_A, y_B): | |
# y_A: [batch, 1] | |
# target_relation: 0: <=, 1: > | |
target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device) | |
target_relation[(y_A - y_B) > 0.0] = 1 | |
return target_relation.squeeze() | |
def forward(self, output_embedding, target, assay_des): | |
batch_repeat_num = len(output_embedding) | |
batch_size = batch_repeat_num // 2 | |
x_A, y_A, x_B, y_B = output_embedding[:batch_size], target[:batch_size],\ | |
output_embedding[batch_size:], target[batch_size:] | |
relation = self.get_rank_relation(y_A, y_B) | |
if self.pretrain_use_assay_description: | |
assay_A, assay_B = assay_des[:batch_size], assay_des[batch_size: ] | |
agg_A = x_A + assay_A | |
agg_B = x_B + assay_B | |
relation_pred = self.relation_mlp(torch.cat([agg_A, agg_B], dim=1)) | |
else: | |
relation_pred = self.relation_mlp(torch.cat([x_A,x_B], dim=1)) | |
ranking_loss = self.loss_fn(relation_pred, relation) | |
_, y_pred = self.m(relation_pred).max(dim=1) | |
return ranking_loss, relation.squeeze(), y_pred | |
# binary cross entropy loss | |
class pairwise_BCE_loss(nn.Module): | |
def __init__(self, config): | |
super(pairwise_BCE_loss, self).__init__() | |
self.config = config | |
self.pretrain_use_assay_description = config.train.pretrain_use_assay_description | |
self.loss_fn = nn.CrossEntropyLoss(reduce=False) | |
if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat': | |
self.relation_mlp = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1) * 2, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2) | |
else: | |
self.relation_mlp = layers.FC(config.model.inter_out_dim * 4, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2) | |
self.m = nn.Softmax(dim=1) | |
def get_rank_relation(self, y_A, y_B): | |
# y_A: [batch, 1] | |
# target_relation: 0: <=, 1: > | |
target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device) | |
target_relation[(y_A - y_B) > 0.0] = 1 | |
return target_relation.squeeze() | |
def forward(self, output_embedding, target, assay_des): | |
batch_repeat_num = len(output_embedding) | |
batch_size = batch_repeat_num // 2 | |
x_A, y_A, x_B, y_B = output_embedding[:batch_size], target[:batch_size],\ | |
output_embedding[batch_size:], target[batch_size:] | |
relation = self.get_rank_relation(y_A, y_B) | |
if self.pretrain_use_assay_description: | |
assay_A, assay_B = assay_des[:batch_size], assay_des[batch_size: ] | |
agg_A = x_A + assay_A | |
agg_B = x_B + assay_B | |
relation_pred = self.relation_mlp(torch.cat([agg_A, agg_B], dim=1)) | |
else: | |
relation_pred = self.relation_mlp(torch.cat([x_A,x_B], dim=1)) | |
ranking_loss = self.loss_fn(relation_pred, relation) | |
_, y_pred = self.m(relation_pred).max(dim=1) | |
return ranking_loss, relation.squeeze(), y_pred |