import torch
import torch.nn as nn
import torch.nn.functional as F
from UltraFlow import layers

# pairwise logistic ranking loss: log(1 + exp(-relation * score_diff)),
# with a 0.2 margin on the labels so near-equal pairs count as ties
class pair_wise_ranking_loss(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # nn.Threshold(0.2, 0) zeroes values <= 0.2, implementing the tie margin
        self.threshold_filter = nn.Threshold(0.2, 0)
        self.score_predict = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, 1)

    def ranking_loss(self, z_A, z_B, relation):
        """
        Logistic ranking loss over a batch of pairs.
        z_A: predicted scores for items A
        z_B: predicted scores for items B
        relation: ground-truth ordering per pair (-1, 0, or 1); tied pairs
        (relation 0) are excluded from the loss.
        """
        score_diff = z_A - z_B
        mask = relation != 0
        # softplus(x) = log(1 + exp(x)), computed in a numerically stable way
        log_loss = torch.mean(F.softplus(-relation[mask] * score_diff[mask]))
        return log_loss

    @torch.no_grad()
    def get_rank_relation(self, y_A, y_B):
        # pairs whose absolute difference falls within the 0.2 margin are
        # treated as ties (relation 0); otherwise the sign gives -1 or 1
        score_diff = y_A - y_B
        score_diff[self.threshold_filter(score_diff.abs()) == 0] = 0

        return score_diff.sign()

    def forward(self, output_embedding, target):
        # the batch holds two stacked halves; row i of the first half is
        # paired with row i of the second half
        batch_repeat_num = len(output_embedding)
        batch_size = batch_repeat_num // 2

        score_predict = self.score_predict(output_embedding)
        x_A, y_A, x_B, y_B = score_predict[:batch_size], target[:batch_size], score_predict[batch_size:], target[batch_size:]

        relation = self.get_rank_relation(y_A, y_B)

        ranking_loss = self.ranking_loss(x_A, x_B, relation)

        # predicted relation, derived from predicted scores with the same margin
        relation_pred = self.get_rank_relation(x_A, x_B)

        return ranking_loss, relation.squeeze(), relation_pred.squeeze()
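
# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the classes in this file): how the
# thresholded relation and logistic ranking loss above behave on toy
# tensors. All values below are made up for demonstration.
def _demo_pair_wise_ranking_loss():
    # ground-truth labels for the two halves of a batch of 3 pairs
    y_A = torch.tensor([[1.0], [2.0], [3.0]])
    y_B = torch.tensor([[0.5], [2.1], [1.0]])
    gap = y_A - y_B                      # [0.5, -0.1, 2.0]
    relation = gap.sign()
    relation[gap.abs() <= 0.2] = 0       # pair 2 is within the margin -> tie
    # predicted scores for the same pairs
    z_A = torch.tensor([[0.8], [1.0], [2.5]])
    z_B = torch.tensor([[0.2], [1.5], [2.7]])
    mask = relation != 0                 # ties contribute nothing to the loss
    loss = F.softplus(-relation[mask] * (z_A - z_B)[mask]).mean()
    return loss                          # scalar logistic ranking loss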

# pairwise ranking cast as two-class classification (cross entropy over <=/> logits)
class pair_wise_ranking_loss_v2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        self.loss_fn = nn.CrossEntropyLoss()
        self.relation_mlp = layers.FC(config.model.inter_out_dim * 4, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
        self.m = nn.Softmax(dim=1)

    @torch.no_grad()
    def get_rank_relation(self, y_A, y_B):
        # y_A: [batch, 1]
        # target_relation: 0: <=, 1: >
        target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device)
        target_relation[(y_A - y_B) > 0.0] = 1

        return target_relation.squeeze()

    def forward(self, output_embedding, target, assay_des):
        batch_repeat_num = len(output_embedding)
        batch_size = batch_repeat_num // 2
        x_A, y_A, x_B, y_B = output_embedding[:batch_size], target[:batch_size],\
                             output_embedding[batch_size:], target[batch_size:]

        relation = self.get_rank_relation(y_A, y_B)

        if self.pretrain_use_assay_description:
            # condition each item's embedding on its assay description
            assay_A, assay_B = assay_des[:batch_size], assay_des[batch_size:]
            agg_A = x_A + assay_A
            agg_B = x_B + assay_B
            relation_pred = self.relation_mlp(torch.cat([agg_A, agg_B], dim=1))
        else:
            relation_pred = self.relation_mlp(torch.cat([x_A, x_B], dim=1))

        ranking_loss = self.loss_fn(relation_pred, relation)

        # softmax is monotonic, so this is just the argmax over the two logits
        _, y_pred = self.m(relation_pred).max(dim=1)

        return ranking_loss, relation, y_pred
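
# ---------------------------------------------------------------------------
# Illustrative sketch: the v2 objective is plain two-class classification
# over ordered pairs. The logits below stand in for relation_mlp output;
# the numbers are made up.
def _demo_relation_classification():
    y_A = torch.tensor([[1.0], [3.0]])
    y_B = torch.tensor([[2.0], [0.5]])
    relation = (y_A > y_B).long().squeeze()           # tensor([0, 1])
    logits = torch.tensor([[1.2, -0.3], [0.1, 0.9]])  # class 0: '<=', class 1: '>'
    loss = nn.CrossEntropyLoss()(logits, relation)
    # softmax is monotonic, so argmax over raw logits gives the same prediction
    y_pred = logits.argmax(dim=1)                     # tensor([0, 1])
    return loss, y_pred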

# same two-class ranking objective, but returning per-pair (unreduced) losses
class pairwise_BCE_loss(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        # reduction='none' keeps one loss value per pair ('reduce=False' is deprecated)
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.relation_mlp = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1) * 2, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
        else:
            self.relation_mlp = layers.FC(config.model.inter_out_dim * 4, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
        self.m = nn.Softmax(dim=1)

    @torch.no_grad()
    def get_rank_relation(self, y_A, y_B):
        # y_A: [batch, 1]
        # target_relation: 0: <=, 1: >
        target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device)
        target_relation[(y_A - y_B) > 0.0] = 1

        return target_relation.squeeze()

    def forward(self, output_embedding, target, assay_des):
        batch_repeat_num = len(output_embedding)
        batch_size = batch_repeat_num // 2
        x_A, y_A, x_B, y_B = output_embedding[:batch_size], target[:batch_size],\
                             output_embedding[batch_size:], target[batch_size:]

        relation = self.get_rank_relation(y_A, y_B)

        if self.pretrain_use_assay_description:
            assay_A, assay_B = assay_des[:batch_size], assay_des[batch_size:]
            agg_A = x_A + assay_A
            agg_B = x_B + assay_B
            relation_pred = self.relation_mlp(torch.cat([agg_A, agg_B], dim=1))
        else:
            relation_pred = self.relation_mlp(torch.cat([x_A, x_B], dim=1))

        # per-pair losses (shape [batch_size]); the caller reduces or re-weights
        ranking_loss = self.loss_fn(relation_pred, relation)

        _, y_pred = self.m(relation_pred).max(dim=1)

        return ranking_loss, relation, y_pred
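
# ---------------------------------------------------------------------------
# Illustrative sketch: with reduction='none' the loss is returned per pair,
# so a caller can mask or re-weight pairs before reducing. The weights here
# are hypothetical (e.g. per-assay confidence) and not part of this file.
def _demo_per_pair_loss():
    logits = torch.tensor([[2.0, -1.0], [0.5, 0.5], [-1.0, 2.0]])
    relation = torch.tensor([0, 1, 1])
    per_pair = nn.CrossEntropyLoss(reduction='none')(logits, relation)  # shape [3]
    weights = torch.tensor([1.0, 0.5, 2.0])  # hypothetical per-pair weights
    return (per_pair * weights).mean()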