import torch
import torch.nn.functional as F
from torch.nn import KLDivLoss
def dot_product_scores(q_vectors, ctx_vectors):
    """
    Calculates q-ctx dot-product scores for every row in ctx_vectors.
    :param q_vectors: question embeddings of shape (n1, D)
    :param ctx_vectors: context embeddings of shape (n2, D)
    :return: score matrix of shape (n1, n2)
    """
    r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
    return r
def cosine_scores(q_vectors, ctx_vectors):
    """
    Calculates q-ctx cosine scores for every row in ctx_vectors.
    :param q_vectors: question embeddings of shape (n1, D)
    :param ctx_vectors: context embeddings of shape (n2, D)
    :return: score matrix of shape (n1, n2)
    """
    # L2-normalize both sides so the pairwise dot products equal cosine
    # similarities, yielding a full (n1, n2) score matrix.
    r = torch.matmul(
        F.normalize(q_vectors, dim=-1),
        torch.transpose(F.normalize(ctx_vectors, dim=-1), 0, 1),
    )
    return r
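
# Illustrative sketch (documentation only, not part of the training code):
# the shapes the scoring functions above expect and return, with a made-up
# batch of 4 questions, 12 contexts, and hidden size 768.
#
#   q = torch.randn(4, 768)
#   c = torch.randn(12, 768)
#   dot_product_scores(q, c).shape   # -> torch.Size([4, 12])
#   cosine_scores(q, c).shape        # -> torch.Size([4, 12])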
class BiEncoderNllLoss(object):
    def __init__(self, score_type="dot", kd_alpha=0.5):
        self.score_type = score_type
        # Weight on the contrastive NLL term; (1 - kd_alpha) goes to the KD term.
        self.kd_alpha = kd_alpha
        # Both inputs to the KD loss are log-probabilities, hence log_target=True.
        self.kd = KLDivLoss(reduction="batchmean", log_target=True)
    def calc(self, q_vectors, ctx_vectors, kd_scores):
        """
        Computes the NLL loss for the given lists of question and ctx vectors,
        blended with a KL-divergence distillation loss against teacher scores.
        :param q_vectors: question embeddings of shape (q_num, D)
        :param ctx_vectors: context embeddings of shape (ctx_num, D)
        :param kd_scores: teacher scores of shape (q_num, ctx_per_q)
        :return: a tuple of the loss value and the number of correct predictions per batch
        """
        scores = self.get_scores(q_vectors, ctx_vectors)
        # Put the teacher targets in log space (self.kd uses log_target=True).
        kd_scores = F.log_softmax(kd_scores, dim=1)
        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            ctx_num = ctx_vectors.size(0)
            ctx_per_q = ctx_num // q_num
            no_hard = ctx_per_q - 1
            scores = scores.view(q_num, -1)
            # Context layout: the first q_num rows of ctx_vectors are the
            # positives, followed by no_hard hard negatives per question,
            # grouped by question. For question i, select its own positive
            # (column i) and its own hard negatives; gather keeps gradients
            # flowing through the selected scores.
            idx = torch.stack(
                [
                    torch.tensor(
                        [i] + list(range(q_num + i * no_hard, q_num + (i + 1) * no_hard))
                    )
                    for i in range(q_num)
                ]
            ).to(scores.device)
            pre_scores = torch.gather(scores, 1, idx)
            positive_idx_per_question = list(range(q_num))
        softmax_scores = F.log_softmax(scores, dim=1)
        pre_softmax_scores = F.log_softmax(pre_scores, dim=1)
        # Contrastive NLL over all in-batch contexts.
        bi_loss = F.nll_loss(
            softmax_scores,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )
        # Distillation loss over each question's (positive + own hard negatives).
        kd_loss = self.kd(pre_softmax_scores, kd_scores)
        max_score, max_idxs = torch.max(softmax_scores, 1)
        loss = self.kd_alpha * bi_loss + (1 - self.kd_alpha) * kd_loss
        correct_predictions_count = (
            max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)
        ).sum()
        return loss, correct_predictions_count
    def get_scores(self, q_vector, ctx_vectors):
        f = self.get_similarity_function()
        return f(q_vector, ctx_vectors)

    def get_similarity_function(self):
        # "dot" selects dot-product scoring; any other value falls back to cosine.
        if self.score_type == "dot":
            return dot_product_scores
        else:
            return cosine_scores
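
# Assumed batch layout for ctx_vectors in both loss classes (inferred from the
# indexing in their calc methods): with q_num questions and no_hard hard
# negatives per question,
#
#   ctx_vectors = [pos_0, ..., pos_{q_num-1},
#                  hard_0_0, ..., hard_0_{no_hard-1},    # hard negatives of q_0
#                  ...,
#                  hard_{q_num-1}_0, ..., hard_{q_num-1}_{no_hard-1}]
#
# so the positive for question i sits at row i.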
class BiEncoderDoubleNllLoss(object):
    def __init__(self, score_type="dot", alpha=0.5):
        self.score_type = score_type
        # Weight on the full in-batch NLL term; (1 - alpha) goes to the NLL
        # term computed without each question's own hard negatives.
        self.alpha = alpha
    def calc(self, q_vectors, ctx_vectors):
        """
        Computes a weighted sum of two NLL losses for the given lists of
        question and ctx vectors: one over all in-batch contexts, and one over
        all contexts except each question's own hard negatives. This weights
        hard vs. regular negatives with different factors.
        :return: a tuple of the loss value and the number of correct predictions per batch
        """
        scores = self.get_scores(q_vectors, ctx_vectors)
        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            ctx_num = ctx_vectors.size(0)
            no_hard = ctx_num // q_num - 1
            scores = scores.view(q_num, -1)
            positive_idx_per_question = list(range(q_num))
            # For each question, drop the columns of its own hard negatives.
            # Only indices >= q_num are removed, so the positive for question i
            # stays at column i in the reduced score matrix.
            keep_idx = torch.stack(
                [
                    torch.tensor(
                        [
                            x
                            for x in range(ctx_num)
                            if x not in range(q_num + i * no_hard, q_num + (i + 1) * no_hard)
                        ]
                    )
                    for i in range(q_num)
                ]
            ).to(scores.device)
            scores2 = torch.gather(scores, 1, keep_idx)
        softmax_scores = F.log_softmax(scores, dim=1)
        softmax_scores2 = F.log_softmax(scores2, dim=1)
        # NLL over all in-batch contexts (positives + all hard negatives).
        loss1 = F.nll_loss(
            softmax_scores,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )
        # NLL over the contexts excluding each question's own hard negatives.
        loss2 = F.nll_loss(
            softmax_scores2,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )
        loss = self.alpha * loss1 + (1 - self.alpha) * loss2
        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (
            max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)
        ).sum()
        return loss, correct_predictions_count
    def get_scores(self, q_vector, ctx_vectors):
        f = self.get_similarity_function()
        return f(q_vector, ctx_vectors)

    def get_similarity_function(self):
        # "dot" selects dot-product scoring; any other value falls back to cosine.
        if self.score_type == "dot":
            return dot_product_scores
        else:
            return cosine_scores
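
if __name__ == "__main__":
    # Smoke-test sketch with assumed shapes and random data: 4 questions,
    # 1 positive + 2 hard negatives per question, hidden size 768. The teacher
    # scores are random stand-ins for a real cross-encoder teacher.
    q = torch.randn(4, 768)
    ctx = torch.randn(12, 768)   # 4 positives followed by 4 * 2 hard negatives
    teacher = torch.randn(4, 3)  # scores over (positive + 2 hard negatives)

    kd_loss_fn = BiEncoderNllLoss(score_type="dot", kd_alpha=0.5)
    loss, n_correct = kd_loss_fn.calc(q, ctx, teacher)
    print("BiEncoderNllLoss:", loss.item(), int(n_correct))

    double_loss_fn = BiEncoderDoubleNllLoss(score_type="dot", alpha=0.5)
    loss2, n_correct2 = double_loss_fn.calc(q, ctx)
    print("BiEncoderDoubleNllLoss:", loss2.item(), int(n_correct2))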