"""Contains all methods relate to iteration."""
import torchtext.data
from onmt.utils.logging import logger
def batch_iter(data, batch_size, batch_size_fn=None, batch_size_multiple=1):
"""Yield elements from data in chunks of batch_size, where each chunk size
is a multiple of batch_size_multiple.
This is an extended version of torchtext.data.batch.
"""
if batch_size_fn is None:
def batch_size_fn(new, count, sofar):
return count
minibatch, size_so_far = [], 0
for ex in data:
minibatch.append(ex)
size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
if size_so_far >= batch_size:
overflowed = 0
if size_so_far > batch_size:
overflowed += 1
if batch_size_multiple > 1:
overflowed += (
(len(minibatch) - overflowed) % batch_size_multiple)
if overflowed == 0:
yield minibatch
minibatch, size_so_far = [], 0
else:
if overflowed == len(minibatch):
logger.warning(
"The batch will be filled until we reach %d,"
"its size may exceed %d tokens"
% (batch_size_multiple, batch_size)
)
else:
yield minibatch[:-overflowed]
minibatch = minibatch[-overflowed:]
size_so_far = 0
for i, ex in enumerate(minibatch):
size_so_far = batch_size_fn(ex, i + 1, size_so_far)
if minibatch:
yield minibatch
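# Illustrative sketch (not part of the original module): how batch_iter chunks
# a toy stream with the default example-count batch_size_fn. The toy data and
# the _demo_batch_iter name are assumptions for demonstration only.
def _demo_batch_iter():
    toy_data = [["a"], ["b", "c"], ["d"], ["e", "f", "g"], ["h"]]
    # batch_size counts examples here; batch_size_multiple=2 keeps full chunks
    # at an even number of examples (the final partial chunk may be smaller).
    for chunk in batch_iter(toy_data, batch_size=4, batch_size_multiple=2):
        print([len(ex) for ex in chunk])  # prints [1, 2, 1, 3], then [1]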
def _pool(data, batch_size, batch_size_fn, batch_size_multiple,
          sort_key, random_shuffler, pool_factor):
    """Sort within pools of pool_factor * batch_size examples, batch each
    pool with batch_iter, and yield the resulting batches in random order."""
    for p in torchtext.data.batch(
            data, batch_size * pool_factor,
            batch_size_fn=batch_size_fn):
        p_batch = list(batch_iter(
            sorted(p, key=sort_key),
            batch_size,
            batch_size_fn=batch_size_fn,
            batch_size_multiple=batch_size_multiple))
        for b in random_shuffler(p_batch):
            yield b
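# Illustrative sketch (not part of the original module): how _pool groups toy
# examples by length and shuffles batches within each pool. The toy data and
# the simple shuffler below are assumptions for demonstration only.
def _demo_pool():
    import random

    toy_data = [["w"] * n for n in (5, 1, 4, 2, 8, 3, 7, 6)]

    def shuffler(seq):
        seq = list(seq)
        random.shuffle(seq)
        return seq

    for b in _pool(toy_data, batch_size=2, batch_size_fn=None,
                   batch_size_multiple=1, sort_key=len,
                   random_shuffler=shuffler, pool_factor=2):
        # each batch holds examples of similar length; batch order is shuffled
        print([len(ex) for ex in b])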
class OrderedIterator(torchtext.data.Iterator):

    def __init__(self,
                 dataset,
                 batch_size,
                 pool_factor=1,
                 batch_size_multiple=1,
                 yield_raw_example=False,
                 **kwargs):
        super(OrderedIterator, self).__init__(dataset, batch_size, **kwargs)
        self.batch_size_multiple = batch_size_multiple
        self.yield_raw_example = yield_raw_example
        self.dataset = dataset
        self.pool_factor = pool_factor

    def create_batches(self):
        if self.train:
            if self.yield_raw_example:
                # Batch one raw example at a time so the consumer can rebatch.
                self.batches = batch_iter(
                    self.data(),
                    1,
                    batch_size_fn=None,
                    batch_size_multiple=1)
            else:
                self.batches = _pool(
                    self.data(),
                    self.batch_size,
                    self.batch_size_fn,
                    self.batch_size_multiple,
                    self.sort_key,
                    self.random_shuffler,
                    self.pool_factor)
        else:
            self.batches = []
            for b in batch_iter(
                    self.data(),
                    self.batch_size,
                    batch_size_fn=self.batch_size_fn,
                    batch_size_multiple=self.batch_size_multiple):
                self.batches.append(sorted(b, key=self.sort_key))

    def __iter__(self):
        """
        Extended version of the definition in torchtext.data.Iterator.
        Added yield_raw_example behaviour to yield a torchtext.data.Example
        instead of a torchtext.data.Batch object.
        """
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a
                    # minibatch be sorted by decreasing order, which
                    # requires reversing relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                if self.yield_raw_example:
                    yield minibatch[0]
                else:
                    yield torchtext.data.Batch(
                        minibatch,
                        self.dataset,
                        self.device)
            if not self.repeat:
                return
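# Illustrative usage sketch (not part of the original module), assuming the
# legacy torchtext.data API (Field / Example / Dataset, roughly torchtext
# 0.4-0.8). All objects built here are hypothetical demo data.
def _demo_ordered_iterator():
    TEXT = torchtext.data.Field(tokenize=str.split)
    fields = [("src", TEXT)]
    examples = [torchtext.data.Example.fromlist([s], fields)
                for s in ["a b c", "d e", "f g h i j"]]
    dataset = torchtext.data.Dataset(examples, fields)
    TEXT.build_vocab(dataset)

    valid_iter = OrderedIterator(
        dataset, batch_size=2, device="cpu",
        train=False, sort=False, sort_within_batch=True,
        sort_key=lambda ex: len(ex.src))
    for batch in valid_iter:
        # each batch.src is a (max_len, batch) LongTensor of token ids
        print(batch.src.shape)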
def max_tok_len(new, count, sofar):
    """
    In the token batching scheme, the number of sequences is limited
    such that the total number of src/tgt tokens (including padding)
    in a batch is <= batch_size.
    """
    # Maintains the longest src and tgt length in the current batch
    global max_src_in_batch, max_tgt_in_batch  # this is a hack
    # Reset current longest length at a new batch (count=1)
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    # Src: [<bos> w1 ... wN <eos>]
    max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2)
    # Tgt: [w1 ... wM <eos>]
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.tgt[0]) + 1)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)
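# Illustrative sketch (not part of the original module): pairing batch_iter
# with max_tok_len for token-based batching. The SimpleNamespace toy examples
# only mimic the (token_list, ...) layout that new.src[0] / new.tgt[0] expect.
def _demo_max_tok_len():
    from types import SimpleNamespace

    def toy_example(src_len, tgt_len):
        return SimpleNamespace(src=(["w"] * src_len,), tgt=(["w"] * tgt_len,))

    examples = [toy_example(s, t) for s, t in [(5, 4), (7, 6), (3, 2), (9, 8)]]
    # batch_size is now a token budget (including padding), not an example count
    for chunk in batch_iter(examples, batch_size=24, batch_size_fn=max_tok_len):
        print(len(chunk), "examples in this batch")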