# Source: Hugging Face file view, uploaded by Tianyinus.
# Commit edcf5ee ("init submit", verified), 975 bytes.
import torch
from torch import nn
# Small constant added to the denominators in NCECriterion.forward so the
# ratios stay finite when a score and the noise term are both ~0.
eps = 1e-7


class NCECriterion(nn.Module):
    """Noise-contrastive estimation (NCE) loss with a constant noise probability.

    Every sample is assigned the same noise probability ``1 / nLem``, i.e. a
    uniform noise distribution over ``nLem`` items.
    """

    def __init__(self, nLem):
        """
        Args:
            nLem: number of items the noise distribution is spread over; the
                per-sample noise probability is ``1 / nLem``.
        """
        super().__init__()
        self.nLem = nLem

    def forward(self, x, targets):
        """Return the NCE loss averaged over the batch.

        Args:
            x: scores treated as model probabilities. ``x.select(1, 0)`` is the
                positive (data) sample and ``x.narrow(1, 1, K)`` the
                ``K = x.size(1) - 1`` noise samples. Assumes non-negative
                values (TODO confirm with the caller); a zero positive score
                gives ``log(0) = -inf``.
            targets: unused; kept so existing call sites keep working.

        Returns:
            ``-(sum log P(origin=model | positive)
            + sum log P(origin=noise | noise samples)) / x.size(0)``,
            as a 1-element tensor for 2-D ``x`` (and for ``x`` of shape
            ``(B, K + 1, 1)``). An empty batch yields NaN (0 / 0).
        """
        batchSize = x.size(0)
        K = x.size(1) - 1
        # Same noise probability for the positive and the noise samples.
        Pn = 1 / float(self.nLem)

        # eq 5.1 : P(origin=model) = Pmt / (Pmt + K*Pn)
        Pmt = x.select(1, 0)
        lnPmt = torch.log(torch.div(Pmt, Pmt.add(K * Pn + eps)))

        # eq 5.2 : P(origin=noise) = K*Pn / (Pms + K*Pn),
        # where Pms are the scores of the noise samples.
        Pon_div = x.narrow(1, 1, K).add(K * Pn + eps)
        # Constant numerator matching the denominator's shape/dtype/device;
        # it needs no gradient.
        Pon = torch.full_like(Pon_div, K * Pn)
        lnPon = torch.log(torch.div(Pon, Pon_div))

        # equation 6 in ref. A: negative log-likelihood, averaged over the batch.
        lnPmtsum = lnPmt.sum(0)
        # Flatten to a column before summing so the noise term is always a
        # 1-element tensor, whatever x's trailing dimensions are.
        lnPonsum = lnPon.view(-1, 1).sum(0)
        return -(lnPmtsum + lnPonsum) / batchSize