import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Linear, ReLU
from torch.autograd import Variable
import numpy as np

from .bert_model import BertForSequenceEncoder


def kernal_mus(n_kernels):
    """
    Get the mu for each Gaussian kernel. Mu is the middle of each bin.
    :param n_kernels: number of kernels (including exact match); the first one is exact match
    :return: l_mu, a list of mu values
    """
    l_mu = [1]
    if n_kernels == 1:
        return l_mu

    bin_size = 2.0 / (n_kernels - 1)  # scores range over [-1, 1]
    l_mu.append(1 - bin_size / 2)  # mu: middle of the bin
    for i in range(1, n_kernels - 1):
        l_mu.append(l_mu[i] - bin_size)
    return l_mu


def kernel_sigmas(n_kernels):
    """
    Get the sigma for each Gaussian kernel.
    :param n_kernels: number of kernels (including exact match)
    :return: l_sigma, a list of sigma values
    """
    l_sigma = [0.001]  # for exact match: tiny variance -> effectively exact match
    if n_kernels == 1:
        return l_sigma

    l_sigma += [0.1] * (n_kernels - 1)
    return l_sigma


class inference_model(nn.Module):
    def __init__(self, bert_model, args):
        super(inference_model, self).__init__()
        self.bert_hidden_dim = args.bert_hidden_dim
        self.dropout = nn.Dropout(args.dropout)
        self.max_len = args.max_len
        self.num_labels = args.num_labels
        self.pred_model = bert_model
        #self.proj_hidden = nn.Linear(self.bert_hidden_dim, 128)
        self.proj_match = nn.Linear(self.bert_hidden_dim, 1)

    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        # Encode the input pair; keep only the pooled sequence representation.
        _, inputs = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        inputs = self.dropout(inputs)
        # Project to a scalar and squash to (-1, 1) as the match score.
        score = self.proj_match(inputs).squeeze(-1)
        score = torch.tanh(score)
        return score
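

# Minimal usage sketch (illustrative, not part of the original module). It prints
# the kernel centers/widths for 11 kernels and scores a dummy batch through
# inference_model. `DummyEncoder` is an assumption standing in for
# BertForSequenceEncoder's (sequence_output, pooled_output) return; `DummyArgs`
# and all tensor sizes below are likewise made up for the demo.
if __name__ == "__main__":
    print(kernal_mus(11))     # [1, 0.9, 0.7, ..., -0.9]: exact match first, then bin midpoints
    print(kernel_sigmas(11))  # [0.001, 0.1, 0.1, ...]: tight kernel for exact match

    class DummyEncoder(nn.Module):
        # Mimics the assumed encoder interface: returns (token states, pooled vector).
        def __init__(self, hidden_dim):
            super().__init__()
            self.hidden_dim = hidden_dim

        def forward(self, inp, msk, seg):
            batch, seq_len = inp.shape
            return (torch.randn(batch, seq_len, self.hidden_dim),
                    torch.randn(batch, self.hidden_dim))

    class DummyArgs:
        bert_hidden_dim = 768
        dropout = 0.1
        max_len = 128
        num_labels = 3

    model = inference_model(DummyEncoder(768), DummyArgs())
    model.eval()  # disable dropout for a deterministic forward pass
    inp = torch.randint(0, 100, (2, 16))
    msk = torch.ones(2, 16, dtype=torch.long)
    seg = torch.zeros(2, 16, dtype=torch.long)
    print(model(inp, msk, seg))  # two scores, each in (-1, 1)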