# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import time

from golden_configs.lm_wikitext2 import MOE as MOEConfig
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import utils

MPI_PORT = 29500


def benchmark_single_process(config_class, args):
    """Benchmark a given model using a single process and multiple devices."""

    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert world_size > 0
    benchmark_config = utils.create_benchmark_config(args.model_name, config_class)
    model_specs = utils.get_model_specs(args.model_name, config_class)

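    # Launch one training process per device (nprocs=world_size); each process runs train() with its own rank.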
    mp.spawn(train, args=(world_size, benchmark_config, model_specs, args), nprocs=world_size, join=True)


def train(rank, world_size, benchmark_config, model_specs, args):
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    utils.init_random_seed(rank)

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    # NCCL requires CUDA devices; fall back to gloo so the CPU-only path used below can still initialize.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    torch.distributed.init_process_group(
        backend=backend, rank=rank, world_size=world_size, init_method=init_method_pgroup
    )
    logger.info("train, rank={}".format(rank))
    device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")

    criterion = benchmark_config["criterion"]

    model_config = utils.create_model_config(
        args, benchmark_config=benchmark_config, model_specs=model_specs, device=device
    )
    # vocab_size may change in create_model_config() due to input data
    vocab_size = model_specs["vocab_size"]
    model = model_config["model"]
    model.train()
    optimizer = model_config["optimizer"]
    optimizer = optimizer(model.parameters())
    group = model.group if hasattr(model, "group") else None
    utils.log_number_of_parameters(model, logger)

    total_loss = 0.0
    total_tokens = 0
    total_tokens_per_log_interval = 0

    total_elapsed = 0.0

    # DDP averages gradients across ranks; device_ids/output_device only apply when running on GPUs.
    if torch.cuda.is_available():
        model = DDP(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)
    else:
        model = DDP(model, broadcast_buffers=False)
    lm_dataloader, _, _ = utils.get_data_loader(
        model_config["dataset_info"], args, benchmark_config, model_specs, num_replicas=world_size, rank=rank
    )

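    # Split a batch into model inputs and next-token targets (the target sequence is shifted by one position).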
    def get_batch(source):
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

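    # Training loop. The first batch is treated as warm-up and excluded from the throughput statistics.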
    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            epoch_start_time = time.time()

        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += batch.numel()

        start_time = time.time()
        optimizer.zero_grad()
        source, target = get_batch(batch)
        source = source.to(device)
        target = target.to(device)
        try:
            output = model(source)
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            total_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
            optimizer.step()
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        elapsed = time.time() - start_time
        if i > 0:
            # Exclude the warm-up batch from aggregate timing so wps matches total_tokens above.
            total_elapsed += elapsed
        log_interval = 1
        total_tokens_per_log_interval += batch.numel()
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            logger.debug(
                "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                    i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                )
            )
            total_tokens_per_log_interval = 0
            total_loss = 0

    wps = total_tokens / total_elapsed if total_elapsed > 0 else 0.0

    logger.debug("rank {}, wps: {}".format(rank, wps))
    if torch.cuda.is_available():
        logger.debug(
            "Peak allocated bytes on cuda:{}: {:1d}".format(
                dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
            )
        )


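# Example invocation (an assumption, since the CLI flags are defined in utils.init_args();
# the flag names below simply mirror the attributes used above, e.g. args.model_name and args.max_batch):
#   python <this_file>.py --model_name <model> --max_batch 10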
if __name__ == "__main__":
    args = utils.init_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    logging.info(f"Running single process benchmark with args: {args}")
    benchmark_single_process(MOEConfig, args)