Spaces:
Build error
Build error
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import BatchNorm1d, Linear, ReLU | |
from .bert_model import BertForSequenceEncoder | |
from torch.nn import BatchNorm1d, Linear, ReLU | |
from .bert_model import BertForSequenceEncoder | |
from torch.autograd import Variable | |
import numpy as np | |
def kernal_mus(n_kernels):
    """
    Compute the mean (mu) of every Gaussian kernel.

    The similarity score range [-1, 1] is divided into ``n_kernels - 1``
    equal-width bins; each soft-match kernel is centred on the middle of
    one bin. The first kernel is the exact-match kernel, centred at 1.

    :param n_kernels: number of kernels (including exact match). The first
        one is the exact-match kernel.
    :return: list of mu values, starting with the exact-match mu (1) and
        followed by the bin centres in decreasing order.
    """
    centers = [1]  # exact-match kernel
    if n_kernels != 1:
        width = 2.0 / (n_kernels - 1)  # score range is [-1, 1]
        current = 1 - width / 2  # middle of the top-most bin
        centers.append(current)
        # Every further kernel sits exactly one bin width below the last.
        for _ in range(n_kernels - 2):
            current = current - width
            centers.append(current)
    return centers
def kernel_sigmas(n_kernels):
    """
    Get the sigma for each Gaussian kernel.

    The first kernel is the exact-match kernel and gets a very small
    sigma (0.001); every remaining kernel gets sigma 0.1.

    :param n_kernels: number of kernels (including exact match).
    :return: l_sigma, a list of sigma values, exact-match sigma first.
    """
    # NOTE: no bin size is needed here. Computing 2.0 / (n_kernels - 1)
    # up front would divide by zero for the single-kernel case.
    l_sigma = [0.001]  # for exact match. small variance -> exact match
    if n_kernels == 1:
        return l_sigma
    l_sigma += [0.1] * (n_kernels - 1)
    return l_sigma
class inference_model(nn.Module):
    """
    Scoring head on top of an encoder.

    Wraps ``bert_model`` and maps the second value it returns to a single
    score per item, squashed into (-1, 1) with ``tanh``.
    """

    def __init__(self, bert_model, args):
        """
        :param bert_model: callable encoder; invoked as
            ``bert_model(inp_tensor, msk_tensor, seg_tensor)`` and expected
            to return a pair whose second element is fed to the scorer.
        :param args: configuration object providing ``bert_hidden_dim``,
            ``dropout``, ``max_len`` and ``num_labels``.
        """
        super().__init__()
        # Plain configuration values copied from args.
        self.bert_hidden_dim = args.bert_hidden_dim
        self.max_len = args.max_len
        self.num_labels = args.num_labels
        # Submodules (registered in the order: dropout, encoder, projection).
        self.dropout = nn.Dropout(args.dropout)
        self.pred_model = bert_model
        self.proj_match = nn.Linear(self.bert_hidden_dim, 1)

    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        """
        Score the given inputs.

        :param inp_tensor: first argument forwarded to the encoder.
        :param msk_tensor: second argument forwarded to the encoder.
        :param seg_tensor: third argument forwarded to the encoder.
        :return: tanh-squashed scores with the trailing size-1 dimension
            of the projection removed.
        """
        # Only the second value returned by the encoder is used.
        _, encoded = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        projected = self.proj_match(self.dropout(encoded))
        return torch.tanh(projected.squeeze(-1))