|
import itertools |
|
import operator |
|
|
|
import attr |
|
import numpy as np |
|
import torch |
|
|
|
|
|
def argsort(items, key=lambda x: x, reverse=False):
    """Sort `items` and return the permutations relating the two orders.

    Args:
        items: sequence to sort.
        key: function applied to each item to obtain its sort key.
        reverse: if True, sort in descending order.

    Returns:
        A tuple `(sorted_items, sort_to_orig, orig_to_sort)` where:
        - `sorted_items`: tuple of the items in sorted order;
        - `sort_to_orig[i]`: position in `sorted_items` of original item i
          (use it to un-sort, e.g. `sorted_array[sort_to_orig]`);
        - `orig_to_sort[j]`: original position of `sorted_items[j]`.
    """
    indexed = sorted(
        enumerate(items), key=lambda pair: key(pair[1]), reverse=reverse)
    # Fix: on empty input, zip(*[]) raised
    # "ValueError: not enough values to unpack"; return empty permutations.
    if not indexed:
        return (), (), ()
    orig_to_sort, sorted_items = zip(*indexed)
    # Invert the permutation: sort positions ordered by original index.
    sort_to_orig = tuple(
        pos for pos, _ in sorted(
            enumerate(orig_to_sort), key=operator.itemgetter(1)))
    return sorted_items, sort_to_orig, orig_to_sort
|
|
|
|
|
def sort_lists_by_length(lists):
    """Sort `lists` from longest to shortest.

    Returns `(sorted_lists, sort_to_orig, orig_to_sort)` as produced by
    `argsort`: `sort_to_orig[i]` is the sorted position of original list i,
    and `orig_to_sort[j]` is the original position of sorted list j.
    """
    sorted_lists, sort_to_orig, orig_to_sort = argsort(
        lists, key=len, reverse=True)
    return sorted_lists, sort_to_orig, orig_to_sort
|
|
|
|
|
def batch_bounds_for_packing(lengths):
    '''Returns how many items in the batch are still active at each step.

    Element i of the result is the number of items whose length is at least
    i + 1, i.e. that still have an entry at 0-based step i.

    Examples:
    [5] -> [1, 1, 1, 1, 1]
    [5, 5] -> [2, 2, 2, 2, 2]
    [5, 3] -> [2, 2, 2, 1, 1]
    [5, 4, 1, 1] -> [4, 2, 2, 2, 1]

    Raises:
        ValueError: if `lengths` is not decreasing or contains a value <= 0.
    '''
    last_length = 0
    count = len(lengths)
    result = []
    # Walk groups of equal lengths from shortest to longest; each group
    # contributes `count` (the number of items at least this long) for the
    # steps between the previous group's length and its own.
    for length, group in itertools.groupby(reversed(lengths)):
        # Fix: the old `i > 0 and length <= last_length` check skipped the
        # first group, so a non-positive smallest length slipped through and
        # silently produced wrong output. With last_length starting at 0,
        # this single comparison rejects both unsorted and non-positive input.
        if length <= last_length:
            raise ValueError('lengths must be decreasing and positive')
        result.extend([count] * (length - last_length))
        count -= sum(1 for _ in group)
        last_length = length
    return result
|
|
|
|
|
def _make_packed_sequence(data, batch_sizes): |
|
return torch.nn.utils.rnn.PackedSequence(data, |
|
torch.LongTensor(batch_sizes)) |
|
|
|
|
|
@attr.s(frozen=True)
class PackedSequencePlus:
    """A torch PackedSequence plus the bookkeeping needed to map between the
    caller's original batch order and the length-sorted order that packing
    requires.

    Attributes:
        ps: the underlying PackedSequence; `ps.data` is in sorted order.
        lengths: sequence lengths in sorted (descending) order.
        sort_to_orig: sort_to_orig[i] is the sorted position of original
            batch item i (use it to un-sort).
        orig_to_sort: orig_to_sort[j] is the original batch index of the
            item at sorted position j.
    """
    ps = attr.ib()
    lengths = attr.ib()
    sort_to_orig = attr.ib(converter=np.array)
    orig_to_sort = attr.ib(converter=np.array)

    @lengths.validator
    def descending(self, attribute, value):
        # Packing requires lengths sorted longest-first.
        for x, y in zip(value, value[1:]):
            if not x >= y:
                raise ValueError('Lengths are not descending: {}'.format(value))

    def __attrs_post_init__(self):
        # cum_batch_sizes[t] is the offset into ps.data where time step t
        # begins. Assigned through __dict__ because the attrs class is frozen.
        self.__dict__['cum_batch_sizes'] = np.cumsum([0] + self.ps.batch_sizes[:-1].tolist()).astype(np.int_)

    def apply(self, fn):
        """Return a copy with `fn` applied to the packed data tensor."""
        return attr.evolve(self, ps=torch.nn.utils.rnn.PackedSequence(
            fn(self.ps.data), self.ps.batch_sizes))

    def with_new_ps(self, ps):
        """Return a copy with the PackedSequence replaced by `ps`."""
        return attr.evolve(self, ps=ps)

    def pad(self, batch_first, others_to_unsort=(), padding_value=0.0):
        """Unpack to a padded tensor restored to the original batch order.

        Returns (padded, lengths) followed by each tensor in
        `others_to_unsort` re-indexed into original order.
        """
        padded, seq_lengths = torch.nn.utils.rnn.pad_packed_sequence(
            self.ps, batch_first=batch_first, padding_value=padding_value)
        results = padded[
            self.sort_to_orig], [seq_lengths[i] for i in self.sort_to_orig]
        return results + tuple(t[self.sort_to_orig] for t in others_to_unsort)

    def cuda(self):
        """Move the packed data to GPU; no-op if it is already there."""
        if self.ps.data.is_cuda:
            return self
        return self.apply(lambda d: d.cuda())

    def raw_index(self, orig_batch_idx, seq_idx):
        """Flat index into ps.data for (original batch item, time step).

        Both arguments may be scalars or array-likes (np.take broadcasts).
        """
        result = np.take(self.cum_batch_sizes, seq_idx) + np.take(
            self.sort_to_orig, orig_batch_idx)
        if self.ps.data is not None:
            assert np.all(result < len(self.ps.data))
        return result

    def select(self, orig_batch_idx, seq_idx=None):
        """Data for one original batch item: the whole sequence when
        `seq_idx` is None, otherwise the single step `seq_idx`."""
        if seq_idx is None:
            return self.ps.data[
                self.raw_index(orig_batch_idx, range(self.lengths[self.sort_to_orig[orig_batch_idx]]))]
        return self.ps.data[self.raw_index(orig_batch_idx, seq_idx)]

    def select_subseq(self, orig_batch_indices):
        """New PackedSequencePlus restricted to the given original items."""
        lengths = [self.lengths[self.sort_to_orig[i]] for i in
                   orig_batch_indices]
        # Fix: from_gather calls map_index with positions into the *new*
        # lengths list, not original batch indices, so raw_index must be
        # called through orig_batch_indices. Passing self.raw_index directly
        # selected the wrong rows whenever orig_batch_indices was not a
        # prefix of range(batch_size).
        return self.from_gather(
            lengths=lengths,
            map_index=lambda batch_idx, seq_idx: self.raw_index(
                orig_batch_indices[batch_idx], seq_idx),
            gather_from_indices=lambda indices:
                self.ps.data[torch.LongTensor(indices)])

    def orig_index(self, raw_idx):
        """Inverse of raw_index: (orig_batch_idx, seq_idx) for a flat index."""
        seq_idx = np.searchsorted(
            self.cum_batch_sizes, raw_idx, side='right') - 1
        batch_idx = raw_idx - self.cum_batch_sizes[seq_idx]
        # Fix: batch_idx is a *sorted* position, so mapping it back to the
        # original order requires orig_to_sort (sorted pos -> orig index).
        # The old code used sort_to_orig, which goes the other way and gave
        # wrong results whenever the permutation was not its own inverse.
        orig_batch_idx = self.orig_to_sort[batch_idx]
        return orig_batch_idx, seq_idx

    def orig_batch_indices(self):
        """Original batch index of every row of ps.data, in packed order."""
        result = []
        for bs in self.ps.batch_sizes:
            result.extend(self.orig_to_sort[:bs])
        return np.array(result)

    def orig_lengths(self):
        """Yield each item's length in original batch order."""
        for sort_idx in self.sort_to_orig:
            yield self.lengths[sort_idx]

    def expand(self, k):
        """Repeat each batch item k times (adjacent copies), preserving the
        sorted-order invariants; returns a new PackedSequencePlus with
        batch size multiplied by k."""
        v = self.ps.data
        # Duplicate every packed row k times so copies stay adjacent within
        # each time step.
        ps_data = v.unsqueeze(1).repeat(1, k, *(
            [1] * (v.dim() - 1))).view(-1, *v.shape[1:])
        batch_sizes = (np.array(self.ps.batch_sizes) * k).tolist()
        lengths = np.repeat(self.lengths, k).tolist()
        # Each original index i becomes the block [i*k, i*k + k); the
        # permutations expand blockwise the same way.
        sort_to_orig = [
            exp_i for i in self.sort_to_orig for exp_i in range(i * k, i * k + k)
        ]
        orig_to_sort = [
            exp_i for i in self.orig_to_sort for exp_i in range(i * k, i * k + k)
        ]
        return PackedSequencePlus(
            _make_packed_sequence(ps_data, batch_sizes),
            lengths, sort_to_orig, orig_to_sort)

    @classmethod
    def from_lists(cls, lists, item_shape, device, item_to_tensor):
        """Pack `lists` (lists of items) by embedding each item.

        `item_to_tensor(item, sorted_batch_idx)` must return a tensor for one
        item; the results are stacked in packed order.
        NOTE(review): `item_shape` and `device` are currently unused; kept
        for interface compatibility with existing callers.
        """
        result_list = []

        sorted_lists, sort_to_orig, orig_to_sort = sort_lists_by_length(lists)
        lengths = [len(lst) for lst in sorted_lists]
        batch_bounds = batch_bounds_for_packing(lengths)
        # Emit step-major: for time step i, one tensor per still-active list.
        for i, bound in enumerate(batch_bounds):
            for batch_idx, lst in enumerate(sorted_lists[:bound]):
                embed = item_to_tensor(lst[i], batch_idx)
                result_list.append(embed)

        result = torch.stack(result_list, 0)
        return cls(
            _make_packed_sequence(result, batch_bounds),
            lengths, sort_to_orig, orig_to_sort)

    @classmethod
    def from_gather(cls, lengths, map_index, gather_from_indices):
        """Build a PackedSequencePlus by gathering existing data.

        `lengths` is in original (unsorted) order; `map_index(batch_idx,
        seq_idx)` returns the source index for one element, and
        `gather_from_indices(indices)` fetches all of them at once.
        """
        sorted_lengths, sort_to_orig, orig_to_sort = argsort(lengths, reverse=True)
        batch_bounds = batch_bounds_for_packing(sorted_lengths)

        indices = []
        # Walk in packed order: step-major, sorted items within each step.
        for seq_idx, bound in enumerate(batch_bounds):
            for batch_idx in orig_to_sort[:bound]:
                assert seq_idx < lengths[batch_idx]
                indices.append(map_index(batch_idx, seq_idx))
        result = gather_from_indices(indices)

        return cls(
            _make_packed_sequence(result, batch_bounds),
            sorted_lengths, sort_to_orig, orig_to_sort)

    @classmethod
    def cat_seqs(cls, items):
        """Concatenate several PackedSequencePlus along the sequence (time)
        dimension; all items must share the same batch size."""
        batch_size = len(items[0].lengths)
        assert all(len(item.lengths) == batch_size for item in items[1:])

        # Per original batch item, total length across all pieces.
        # Fix: np.int was removed in NumPy 1.24; use the explicit np.int64.
        unsorted_concat_lengths = np.zeros(batch_size, dtype=np.int64)
        for item in items:
            unsorted_concat_lengths += list(item.orig_lengths())

        concat_data = torch.cat([item.ps.data for item in items], dim=0)
        concat_data_base_indices = np.cumsum([0] + [item.ps.data.shape[0] for item in items])

        # For each batch item, enumerate (piece, step-within-piece) pairs in
        # concatenation order so a global seq_idx can be resolved.
        item_map_per_batch_item = []
        for batch_idx in range(batch_size):
            item_map_per_batch_item.append([
                (item_idx, item, i)
                for item_idx, item in enumerate(items)
                for i in range(item.lengths[item.sort_to_orig[batch_idx]])])

        def map_index(batch_idx, seq_idx):
            item_idx, item, seq_idx_within_item = item_map_per_batch_item[batch_idx][seq_idx]
            return concat_data_base_indices[item_idx] + item.raw_index(batch_idx, seq_idx_within_item)

        return cls.from_gather(
            lengths=unsorted_concat_lengths,
            map_index=map_index,
            gather_from_indices=lambda indices: concat_data[torch.LongTensor(indices)])
|
|