"""
PLEASE NOTE: THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
OF THE MHG IMPLEMENTATION BY HIROSHI KAJINO AT IBM TRL, WHICH IS ALREADY PUBLICLY AVAILABLE.
THIS MIGHT INFLUENCE THE CHOICE OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
"""

""" Dataset and padding utilities for production-rule sequences. """

__author__ = "Hiroshi Kajino <[email protected]>"
__copyright__ = "(c) Copyright IBM Corp. 2018"
__version__ = "0.1"
__date__ = "Apr 18 2018"

from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np


def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
    ''' pad each sequence on the left so that all of them have length `max_len`

    Parameters
    ----------
    sentence_list : list of sequences of integers
    max_len : int
        maximum length of the sentences.
        if a sentence is shorter than `max_len`, it is padded on the left.
    pad_idx : int
        integer used for padding
    inverse : bool
        if True, each sequence is reversed before padding.

    Returns
    -------
    List of torch.LongTensor
        each sentence is left-padded.
    '''
    max_in_list = max(len(each_sen) for each_sen in sentence_list)
    if max_in_list > max_len:
        raise ValueError('`max_len` must be at least the maximum length '
                         'of the input sequences, {}.'.format(max_in_list))

    if inverse:
        return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen[::-1])
                for each_sen in sentence_list]
    return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen)
            for each_sen in sentence_list]
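
# A minimal usage sketch (not part of the original source): with the default
# pad_idx=-1, left_padding([[1, 2], [3]], max_len=4) yields
#   [tensor([-1, -1,  1,  2]), tensor([-1, -1, -1,  3])],
# and with inverse=True each sequence is reversed before padding:
#   [tensor([-1, -1,  2,  1]), tensor([-1, -1, -1,  3])].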


def right_padding(sentence_list, max_len, pad_idx=-1):
    ''' pad each sequence on the right so that all of them have length `max_len`

    Parameters
    ----------
    sentence_list : list of sequences of integers
    max_len : int
        maximum length of the sentences.
        if a sentence is shorter than `max_len`, it is padded on the right.
    pad_idx : int
        integer used for padding

    Returns
    -------
    List of torch.LongTensor
        each sentence is right-padded.
    '''
    max_in_list = max(len(each_sen) for each_sen in sentence_list)
    if max_in_list > max_len:
        raise ValueError('`max_len` must be at least the maximum length '
                         'of the input sequences, {}.'.format(max_in_list))

    return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen)))
            for each_sen in sentence_list]
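
# A minimal usage sketch (not part of the original source): for the same toy
# input, right_padding([[1, 2], [3]], max_len=4) yields
#   [tensor([ 1,  2, -1, -1]), tensor([ 3, -1, -1, -1])].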


class HRGDataset(Dataset):

    ''' A dataset of production-rule sequences derived from a hyperedge
    replacement grammar (HRG), with optional target values.
    '''

    def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
        self.hrg = hrg
        self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
                                                    max_len,
                                                    inverse=inversed_input)
        self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
        self.inversed_input = inversed_input
        self.target_val_list = target_val_list
        if target_val_list is not None:
            if len(prod_rule_seq_list) != len(target_val_list):
                raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: '
                                 f'{len(prod_rule_seq_list)}, {len(target_val_list)}')

    def __len__(self):
        return len(self.left_prod_rule_seq_list)

    def __getitem__(self, idx):
        if self.target_val_list is not None:
            return (self.left_prod_rule_seq_list[idx],
                    self.right_prod_rule_seq_list[idx],
                    np.float32(self.target_val_list[idx]))
        return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]

    @property
    def vocab_size(self):
        return self.hrg.num_prod_rule
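
# A minimal usage sketch (assumptions: `hrg` is a hyperedge replacement grammar
# object exposing a `num_prod_rule` attribute, and `seqs` is a list of
# production-rule index sequences; both names are hypothetical):
#
#   dataset = HRGDataset(hrg, seqs, max_len=40)
#   loader = DataLoader(dataset, batch_size=64, shuffle=True)
#   for left_batch, right_batch in loader:
#       ...  # both are LongTensors of shape (batch_size, max_len)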


def batch_padding(each_batch, batch_size, padding_idx):
    ''' append rows filled with `padding_idx` to `each_batch[0]` and `each_batch[1]`
    so that the batch contains exactly `batch_size` examples.

    Returns the padded batch and the number of dummy rows appended.
    '''
    num_pad = batch_size - len(each_batch[0])
    if num_pad:
        each_batch[0] = torch.cat([each_batch[0],
                                   padding_idx * torch.ones((num_pad, len(each_batch[0][0])),
                                                            dtype=torch.int64)], dim=0)
        each_batch[1] = torch.cat([each_batch[1],
                                   padding_idx * torch.ones((num_pad, len(each_batch[1][0])),
                                                            dtype=torch.int64)], dim=0)
    return each_batch, num_pad
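
# A minimal usage sketch (hypothetical, not in the original source): top up the
# last, smaller batch from a DataLoader so that every batch has the same size,
# then discard the model outputs that correspond to the dummy rows.
#
#   each_batch, num_pad = batch_padding(list(each_batch), batch_size=64, padding_idx=-1)
#   output = model(each_batch[0], each_batch[1])  # `model` is hypothetical
#   if num_pad:
#       output = output[:-num_pad]  # drop predictions for the dummy rows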