jiangjiechen's picture
init loren for spaces
7f7285f
raw
history blame
2.03 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Linear, ReLU
from .bert_model import BertForSequenceEncoder
from torch.nn import BatchNorm1d, Linear, ReLU
from .bert_model import BertForSequenceEncoder
from torch.autograd import Variable
import numpy as np
def kernal_mus(n_kernels):
    """
    Return the centre (mu) of every Gaussian kernel.

    Kernel 0 is the exact-match kernel and is centred at 1. The remaining
    ``n_kernels - 1`` kernels split the score range [-1, 1] into equal bins of
    width ``2 / (n_kernels - 1)``. Each of them is centred on the midpoint of
    its bin, listed from the highest bin to the lowest.
    (The misspelled name is kept so existing callers keep working.)

    :param n_kernels: number of kernels, including the exact-match kernel.
    :return: list of mus, starting with the exact-match mu of 1.
    """
    mus = [1]
    if n_kernels == 1:
        return mus

    width = 2.0 / (n_kernels - 1)  # scores range over [-1, 1]
    centre = 1 - width / 2  # midpoint of the top bin
    mus.append(centre)
    # Walk down one bin at a time; each centre is the previous one minus a
    # full bin width (accumulated, not recomputed, to keep identical floats).
    for _ in range(n_kernels - 2):
        centre -= width
        mus.append(centre)
    return mus
def kernel_sigmas(n_kernels):
    """
    Return the width (sigma) of every Gaussian kernel.

    Kernel 0 is the exact-match kernel and gets a very small sigma (0.001).
    Every remaining kernel shares a fixed sigma of 0.1.

    :param n_kernels: number of kernels, including the exact-match kernel.
    :return: l_sigma, a list of ``max(n_kernels, 1)`` sigmas, starting with
        the exact-match sigma.
    """
    # No bin width is needed here: all non-exact kernels share one fixed sigma.
    # (Computing 2.0 / (n_kernels - 1) before the early return would divide by
    # zero when n_kernels == 1.)
    l_sigma = [0.001]  # for exact match. small variance -> exact match
    if n_kernels == 1:
        return l_sigma
    l_sigma += [0.1] * (n_kernels - 1)
    return l_sigma
class inference_model(nn.Module):
    """
    Score inputs with a wrapped encoder followed by a tanh-bounded linear head.

    ``bert_model`` is called as ``bert_model(inp, msk, seg)`` and must return a
    2-tuple; only the second element is used. That element goes through
    dropout, then a ``Linear(bert_hidden_dim, 1)`` projection. The trailing
    size-1 dimension is squeezed away and the result is squashed with tanh, so
    every score lies in (-1, 1).

    :param bert_model: encoder returning a 2-tuple; presumably a
        ``BertForSequenceEncoder`` whose second output has ``bert_hidden_dim``
        features in its last dimension -- TODO confirm.
    :param args: namespace providing ``bert_hidden_dim``, ``dropout``,
        ``max_len`` and ``num_labels``.
    """

    def __init__(self, bert_model, args):
        super().__init__()
        hidden = args.bert_hidden_dim
        self.bert_hidden_dim = hidden
        # Submodule registration order (dropout, encoder, head) fixes the order
        # of parameters() and state_dict() entries; keep it stable.
        self.dropout = nn.Dropout(args.dropout)
        # max_len / num_labels are stored but not used inside this class;
        # presumably read by callers.
        self.max_len = args.max_len
        self.num_labels = args.num_labels
        self.pred_model = bert_model
        self.proj_match = nn.Linear(hidden, 1)

    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        """
        Return tanh-bounded scores for the given encoder inputs.

        The three tensors are forwarded unchanged to ``self.pred_model``
        (presumably token ids, attention mask and segment ids -- confirm).

        :return: tensor of scores in (-1, 1), shaped like the encoder's second
            output with its last (hidden) dimension removed.
        """
        _, encoded = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        projected = self.proj_match(self.dropout(encoded))
        return torch.tanh(projected.squeeze(-1))