import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, dataset_max_length: int, null_label: int):
        super().__init__()
        self.max_length = dataset_max_length + 1  # additional stop token
        self.null_label = null_label

    def _get_length(self, logit, dim=-1):
        """Greedy decoder: infer each sequence's length from the logits by
        locating the first predicted null (stop) token."""
        # True wherever the argmax prediction is the null/stop label.
        out = (logit.argmax(dim=-1) == self.null_label)
        # Whether each sequence contains at least one null token.
        abn = out.any(dim)
        # Index of the first null token (cumsum == 1 marks only the first occurrence).
        out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
        out = out + 1  # additional end token
        # Sequences with no null token fall back to the full sequence length.
        out = torch.where(abn, out, out.new_tensor(logit.shape[1], device=out.device))
        return out

    @staticmethod
    def _get_padding_mask(length, max_length):
        """Boolean mask that is True at padded positions beyond each sequence's length."""
        length = length.unsqueeze(-1)
        grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
        return grid >= length

    @staticmethod
    def _get_location_mask(sz, device=None):
        """Attention mask with -inf on the diagonal so a position cannot attend to itself."""
        mask = torch.eye(sz, device=device)
        mask = mask.float().masked_fill(mask == 1, float('-inf'))
        return mask
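

# Illustrative usage sketch (not part of the original file): exercises the three
# helpers on random logits. The batch size, charset size of 37, max length of 25,
# and null_label of 0 are assumptions chosen only for this demo.
if __name__ == "__main__":
    model = Model(dataset_max_length=25, null_label=0)
    # (batch, sequence, charset) logits; the sequence dimension includes the stop token.
    logits = torch.randn(2, model.max_length, 37)
    lengths = model._get_length(logits)                            # per-sample decoded length
    pad_mask = Model._get_padding_mask(lengths, model.max_length)  # True at padded positions
    loc_mask = Model._get_location_mask(model.max_length)          # -inf on the diagonal
    print(lengths, pad_mask.shape, loc_mask.shape)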