# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging

import datasets.wikitext2_data as wikitext2_data
from fairscale.nn.pipe import LazyModule  # used by get_lm_model's lazy path; import location assumed
from models import transformer_lm
import numpy as np
import torch
from torch.optim import Adam


def init_random_seed(seed: int):
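    """Seed the torch, CUDA, and numpy RNGs so benchmark runs are reproducible."""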
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def init_args():
    parser = argparse.ArgumentParser(description="benchmark")
    parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
    parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
    parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
    parser.add_argument(
        "--checkpoint",
        default="never",
        choices=["always", "except_last", "never"],
        help="Checkpointing strategy for pipe",
    )
    parser.add_argument(
        "--lazy-construction", action="store_true", default=False, help="Construct the model lazily, layer by layer"
    )
    parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
    parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
    parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
    parser.add_argument(
        # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
        "--model_name",
        default="lm",
        help="Language Model (LM) used to benchmark nn.pipe.",
    )
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
    args = parser.parse_args()
    return args


def create_benchmark_config(model_name, config_class):
    """Return a dict with the configuration required for benchmarking the `model_name` model."""

    if model_name == "lm":
        return config_class.get_benchmark_config()
    else:
        raise RuntimeError("Unrecognized args.model_name: %s" % model_name)


def get_model_specs(model_name, config_class):
    """Return a dict with the model architecture configuration for the `model_name` model."""

    if model_name == "lm":
        return config_class.get_model_config()
    else:
        raise RuntimeError("Unrecognized args.model_name: %s" % model_name)


def create_model_config(args, benchmark_config=None, model_specs=None, device=None):
    """Return a dict with the instantiated model, optimizer factory, and dataset info."""

    if not device:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dataset_info = get_dataset_info(args)
    assert model_specs is not None
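    # The vocabulary size depends on the dataset, so it is filled in here
    # rather than baked into the static model specs.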
    model_specs["vocab_size"] = dataset_info.ntokens
    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {
        "model": model,
        "optimizer": optimizer,
        "dataset_info": dataset_info,
    }


def get_model_and_optimizer(args, device, benchmark_config, model_config):
    """Return instantiated model and optimizer function."""

    if args.model_name == "lm":
        model = get_lm_model(args, device, model_config)
    else:
        raise RuntimeError("Unrecognized args.model_name: %s" % args.model_name)

    lr = benchmark_config["lr"]

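    # Return an optimizer factory rather than an Adam instance: the caller may
    # still wrap or partition the model (e.g. for pipe parallelism), so the
    # optimizer should be constructed only after the final parameters exist.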
    def make_adam(params):
        return Adam(params, lr=lr)

    optimizer = make_adam
    return model, optimizer


def get_lm_model(args, device, config):
    """Get the language model (based on GPT-2) used for sequence prediction."""

    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]
    is_moe = config.get("is_moe", False)
    num_local_experts = config.get("num_local_experts", 1)

    if args.lazy_construction:
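        # Wrap each layer in a LazyModule so its construction is deferred until
        # the pipeline assigns it to a device; this avoids materializing the
        # whole model on a single device up front.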
        layers = [
            LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
            LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            layers.append(
                LazyModule(
                    lambda: transformer_lm.TransformerDecoderLayer(
                        ninp, nhead, nhid, dropout, is_moe, num_local_experts
                    )
                )
            )

        layers.append(LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)))
        model = layers
    else:
        model = transformer_lm.TransformerLM(
            vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder, is_moe, num_local_experts
        ).to(device)

    return model


def log_number_of_parameters(model, logger=None):
    if not logger:
        logger = logging
    num_params = sum(p.numel() for p in model.parameters())
    if hasattr(model, "group"):
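        # The model is partitioned across a process group; all_reduce sums the
        # per-partition parameter counts so rank 0 can report the global total.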
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logger.debug(
            f"training model, #params = {num_params}, group rank: {model.group.rank()}, global rank:"
            f" {torch.distributed.get_rank()}, group size: {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logger.debug(f"total #params = {total.item()}")
    else:
        logger.debug(f"training model, #params = {num_params}")


def get_dataset_info(args):
    assert args.model_name == "lm"
    if args.use_synthetic_data:
        return wikitext2_data.get_synthetic_datasets()
    else:
        return wikitext2_data.get_real_datasets()


def get_data_loader(dataset_info, args, benchmark_config, model_specs, num_replicas=1, rank=0):
    return wikitext2_data.get_dataloaders(dataset_info, benchmark_config, model_specs, num_replicas, rank)
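

# Minimal usage sketch (hypothetical driver: the real benchmark script also
# sets up pipe parallelism and the training loop; `lm_wikitext2` stands in for
# a config class exposing get_benchmark_config() / get_model_config()):
#
#   args = init_args()
#   init_random_seed(0)
#   benchmark_config = create_benchmark_config(args.model_name, lm_wikitext2)
#   model_specs = get_model_specs(args.model_name, lm_wikitext2)
#   model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
#   dataloaders = get_data_loader(model_config["dataset_info"], args, benchmark_config, model_specs)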