import torch
import torch.nn as nn
import torch.distributed as dist
import numpy as np
from tencentpretrain.utils.misc import pooling


class ClrTarget(nn.Module):
    """
    Contrastive learning target: matched pairs from the two encoder streams are
    pulled together and mismatched in-batch pairs are pushed apart.
    """

    def __init__(self, args, vocab_size):
        super(ClrTarget, self).__init__()
        self.vocab_size = vocab_size
        self.batch_size = args.batch_size
        self.criterion_0 = nn.CrossEntropyLoss()
        self.criterion_1 = nn.CrossEntropyLoss()
        self.softmax = nn.LogSoftmax(dim=-1)
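        # Learnable temperature, stored in log space; its exp() scales the
        # cosine-similarity logits in forward(). Initialized to log(1/0.07), the
        # same starting value used by CLIP.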
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.pooling_type = [args.stream_0["pooling"], args.stream_1["pooling"]]
        if args.projection:
            self.projection = True
            self.encoder_0_projection = nn.Parameter(torch.randn(args.stream_0["hidden_size"], args.feature_size))
            self.encoder_1_projection = nn.Parameter(torch.randn(args.stream_1["hidden_size"], args.feature_size))
        else:
            self.projection = False

    def forward(self, memory_bank, tgt, seg):
        """
        Args:
            memory_bank: pair of hidden states, each of size [batch_size x seq_length x hidden_size]
            tgt: unused; targets are the in-batch positions of the matched pairs
            seg: pair of segment tensors of size [batch_size x seq_length]
        Returns:
            loss: Symmetric contrastive (InfoNCE) loss.
            correct: Number of pairs that are matched correctly.
        """
        embedding_0, embedding_1 = memory_bank
        features_0 = pooling(embedding_0, seg[0], self.pooling_type[0])
        features_1 = pooling(embedding_1, seg[1], self.pooling_type[1])
        if self.projection:
            features_0 = torch.matmul(features_0, self.encoder_0_projection)
            features_1 = torch.matmul(features_1, self.encoder_1_projection)
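        # L2-normalize the pooled features so the dot products below are cosine similarities.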
        features_0 = features_0 / features_0.norm(dim=-1, keepdim=True)
        features_1 = features_1 / features_1.norm(dim=-1, keepdim=True)

        # https://github.com/princeton-nlp/SimCSE/blob/main/simcse/models.py#L169
        # Gather all embeddings if using distributed training.
        if dist.is_initialized():
            # Dummy vectors for allgather.
            features_0_list = [torch.zeros_like(features_0) for _ in range(dist.get_world_size())]
            features_1_list = [torch.zeros_like(features_1) for _ in range(dist.get_world_size())]
            # Allgather.
            dist.all_gather(tensor_list=features_0_list, tensor=features_0.contiguous())
            dist.all_gather(tensor_list=features_1_list, tensor=features_1.contiguous())
            # Since allgather results do not have gradients, we replace the
            # current process's corresponding embeddings with the original tensors.
            features_0_list[dist.get_rank()] = features_0
            features_1_list[dist.get_rank()] = features_1
            # Get full batch embeddings: (bs x N, hidden).
            features_0 = torch.cat(features_0_list, 0)
            features_1 = torch.cat(features_1_list, 0)

        # Cosine similarity as logits.
        logit_scale = self.logit_scale.exp()
        logits_0 = logit_scale * torch.matmul(features_0, features_1.transpose(-2, -1))
        logits_1 = logit_scale * torch.matmul(features_1, features_0.transpose(-2, -1))
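        # The i-th sample in stream 0 matches the i-th sample in stream 1, so the
        # targets are the diagonal indices; the loss is a symmetric InfoNCE that
        # averages the cross-entropy over both matching directions.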
        tgt = torch.arange(features_0.size()[0], device=logits_0.device, dtype=torch.long)
        loss = (self.criterion_0(logits_0, tgt) + self.criterion_1(logits_1, tgt)) / 2
        if dist.is_initialized():
            correct = self.softmax(logits_0).argmax(dim=-1).eq(tgt).sum() / dist.get_world_size()
        else:
            correct = self.softmax(logits_0).argmax(dim=-1).eq(tgt).sum()

        return loss, correct
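

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the training pipeline):
    # it shows the shapes ClrTarget expects. The Namespace below is a stand-in for
    # the framework's real args object and only carries the fields that __init__
    # reads; the concrete sizes and pooling choice are arbitrary assumptions.
    from argparse import Namespace

    args = Namespace(
        batch_size=4,
        projection=True,
        feature_size=128,
        stream_0={"pooling": "mean", "hidden_size": 768},
        stream_1={"pooling": "mean", "hidden_size": 512},
    )
    target = ClrTarget(args, vocab_size=30000)

    batch_size, seq_length = 4, 16
    # One hidden-state tensor and one segment tensor per stream.
    memory_bank = (torch.randn(batch_size, seq_length, 768),
                   torch.randn(batch_size, seq_length, 512))
    seg = (torch.ones(batch_size, seq_length, dtype=torch.long),
           torch.ones(batch_size, seq_length, dtype=torch.long))

    loss, correct = target(memory_bank, None, seg)
    print("loss:", loss.item(), "correct pairs:", correct.item())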