File size: 2,310 Bytes
8778cfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Utilities for splitting batches of examples into smaller sub-batches.

This is useful during training when a batch is too large to fit in GPU memory,
meaning that gradient accumulation across multiple sub-batches must be used.
It is also useful for batching examples during evaluation. Unlike a naive
approach, this code groups examples with similar costs (e.g. lengths) to
reduce the amount of wasted computation due to padding.
"""

import numpy as np


def split(*data, costs, max_cost):
    """Splits a batch of input items into sub-batches.

    Items are visited in order of increasing cost and grouped greedily. The
    cost of a sub-batch is computed as its number of items multiplied by the
    cost of its most expensive item (when costs are lengths, this is the size
    of the sub-batch after padding every item to the longest one). A sub-batch
    keeps growing for as long as adding the next item keeps that cost at or
    below max_cost.

    Args:
        *data: One or more lists of input items, all of the same length
        costs: A list of costs for each item. Costs are converted to integers.
        max_cost: Maximum total cost for each sub-batch

    Yields:
        (example_ids, *subbatch_data) tuples, where example_ids is a list of
        indices into the original inputs and each entry of subbatch_data is
        the list of items selected from the corresponding list in data. Every
        item appears in exactly one sub-batch. An item whose own cost exceeds
        max_cost is still yielded, in a sub-batch by itself.
    """
    costs = np.asarray(costs, dtype=int)
    # Item indices ordered from cheapest to most expensive, so the item that
    # would be added next is always the most expensive one in the sub-batch.
    order = np.argsort(costs).tolist()
    num_items = len(order)

    start = 0
    while start < num_items:
        # Always take at least one item so that progress is guaranteed.
        size = 1
        # Adding the item at position start + size would make the sub-batch
        # hold size + 1 items, each charged at that item's cost.
        while (
            start + size < num_items
            and (size + 1) * costs[order[start + size]] <= max_cost
        ):
            size += 1

        subbatch_item_ids = order[start:start + size]
        subbatch_data = [[items[i] for i in subbatch_item_ids] for items in data]
        yield (subbatch_item_ids,) + tuple(subbatch_data)
        start += size


def map(func, *data, costs, max_cost, **common_kwargs):
    """Applies a function sub-batch by sub-batch and restores item order.

    The inputs are divided into sub-batches with split(), func is called once
    per sub-batch, and each per-item output is written back to the position
    its item had in the original inputs.

    Args:
        func: Function to call on each sub-batch. It receives one list per
            entry of data and should return one output per item, in the same
            order as the items it was given.
        *data: One or more lists of input items, all of the same length.
        costs: A list of costs for each item
        max_cost: Maximum total cost for each sub-batch
        **common_kwargs: Keyword arguments forwarded unchanged to every call
            of func

    Returns:
        A list with one entry per item of data[0], holding the outputs of
        func arranged in the original item order.
    """
    # Pre-size the result list because outputs arrive in cost order, not in
    # the original item order.
    outputs = [None] * len(data[0])
    subbatches = split(*data, costs=costs, max_cost=max_cost)
    for subbatch in subbatches:
        original_positions = subbatch[0]
        subbatch_args = subbatch[1:]
        subbatch_results = func(*subbatch_args, **common_kwargs)
        for position, result in zip(original_positions, subbatch_results):
            outputs[position] = result
    return outputs