import torch
import torch.nn as nn
import torch.distributed as dist
import numpy as np
from tencentpretrain.utils.misc import pooling


class ClrTarget(nn.Module):
    """
    """
    def __init__(self, args, vocab_size):
        super(ClrTarget, self).__init__()
        self.vocab_size = vocab_size
        self.batch_size = args.batch_size

        self.criterion_0 = nn.CrossEntropyLoss()
        self.criterion_1 = nn.CrossEntropyLoss()
        self.softmax = nn.LogSoftmax(dim=-1)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.pooling_type = [args.stream_0["pooling"], args.stream_1["pooling"]]

        if args.projection:
            self.projection = True
            self.encoder_0_projection = nn.Parameter(torch.randn(args.stream_0["hidden_size"], args.feature_size))
            self.encoder_1_projection = nn.Parameter(torch.randn(args.stream_1["hidden_size"], args.feature_size))
        else:
            self.projection = False


    def forward(self, memory_bank, tgt, seg):
        """
        Args:
            memory_bank: [batch_size x seq_length x hidden_size]
            tgt: [batch_size]

        Returns:
            loss: Classification loss.
            correct: Number of sentences that are predicted correctly.
        """
        embedding_0, embedding_1 = memory_bank
        features_0 = pooling(embedding_0, seg[0], self.pooling_type[0])
        features_1 = pooling(embedding_1, seg[1], self.pooling_type[1])
        if self.projection:
            features_0 = torch.matmul(features_0, self.encoder_0_projection)
            features_1 = torch.matmul(features_1, self.encoder_1_projection)

        features_0 = features_0 / features_0.norm(dim=-1, keepdim=True)
        features_1 = features_1 / features_1.norm(dim=-1, keepdim=True)

        # https://github.com/princeton-nlp/SimCSE/blob/main/simcse/models.py#L169
        # Gather all embeddings if using distributed training
        if dist.is_initialized():
            # Dummy vectors for allgather
            features_0_list = [torch.zeros_like(features_0) for _ in range(dist.get_world_size())]
            features_1_list = [torch.zeros_like(features_1) for _ in range(dist.get_world_size())]

            # Allgather
            dist.all_gather(tensor_list=features_0_list, tensor=features_0.contiguous())
            dist.all_gather(tensor_list=features_1_list, tensor=features_1.contiguous())

            # Since allgather results do not have gradients, we replace the
            # current process's corresponding embeddings with original tensors
            features_0_list[dist.get_rank()] = features_0
            features_1_list[dist.get_rank()] = features_1

            # Get full batch embeddings: (bs x N, hidden)
            features_0 = torch.cat(features_0_list, 0)
            features_1 = torch.cat(features_1_list, 0)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_0 = logit_scale * torch.matmul(features_0, features_1.transpose(-2, -1))
        logits_1 = logit_scale * torch.matmul(features_1, features_0.transpose(-2, -1))


        tgt = torch.arange(features_0.size(0), device=logits_0.device, dtype=torch.long)
        loss = (self.criterion_0(logits_0, tgt) + self.criterion_1(logits_1, tgt)) / 2
        if dist.is_initialized():
            correct = self.softmax(logits_0).argmax(dim=-1).eq(tgt).sum() / dist.get_world_size()
        else:
            correct = self.softmax(logits_0).argmax(dim=-1).eq(tgt).sum()
        return loss, correct
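

# A minimal usage sketch (illustrative only, not part of the original module):
# the attributes set on `args` below (batch_size, feature_size, projection,
# stream_0/stream_1 with "pooling" and "hidden_size") mirror what __init__
# reads from its `args` object, while the tensor shapes, hidden sizes, and the
# "first" pooling choice are assumptions for demonstration rather than values
# taken from a real TencentPretrain config.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        batch_size=4,
        feature_size=128,
        projection=True,
        stream_0={"pooling": "first", "hidden_size": 768},
        stream_1={"pooling": "first", "hidden_size": 512},
    )
    target = ClrTarget(args, vocab_size=0)

    batch_size, seq_len = 4, 16
    # One memory bank per stream: [batch_size x seq_length x hidden_size].
    memory_bank = (torch.randn(batch_size, seq_len, 768),
                   torch.randn(batch_size, seq_len, 512))
    # Segment masks of ones mark every position as valid for pooling.
    seg = (torch.ones(batch_size, seq_len, dtype=torch.long),
           torch.ones(batch_size, seq_len, dtype=torch.long))

    loss, correct = target(memory_bank, None, seg)
    print(loss.item(), correct.item())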