# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import time

from golden_configs.lm_wikitext2 import MOE as MOEConfig
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import utils

MPI_PORT = 29500


def benchmark_single_process(config_class, args):
"""Benchmark a given model using a single process and multiple devices."""
    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert world_size > 0
    benchmark_config = utils.create_benchmark_config(args.model_name, config_class)
    model_specs = utils.get_model_specs(args.model_name, config_class)
    mp.spawn(train, args=(world_size, benchmark_config, model_specs, args), nprocs=world_size, join=True)


def train(rank, world_size, benchmark_config, model_specs, args):
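    """Per-process training loop: joins the distributed group, builds the MoE model and
    optimizer from the benchmark config, wraps the model in DDP, and reports loss,
    throughput, and peak CUDA memory for this rank."""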
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    utils.init_random_seed(rank)
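    # All spawned ranks rendezvous over TCP on the local host and join a single NCCL process group.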
    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )
logger.info("train, rank={}".format(rank))
device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")
criterion = benchmark_config["criterion"]
model_config = utils.create_model_config(
args, benchmark_config=benchmark_config, model_specs=model_specs, device=device
)
# vocab_size may change in create_model_config() due to input data
vocab_size = model_specs["vocab_size"]
model = model_config["model"]
model.train()
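    # model_config["optimizer"] is a constructor; instantiate it with this replica's parameters.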
    optimizer = model_config["optimizer"]
    optimizer = optimizer(model.parameters())
    group = model.group if hasattr(model, "group") else None
    utils.log_number_of_parameters(model, logger)
    total_loss = 0.0
    word_counter = 0
    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    total_elapsed = 0.0
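    # Replicate the model with DDP; broadcast_buffers=False skips re-syncing buffers on every forward pass.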
    model = DDP(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)
    lm_dataloader, _, _ = utils.get_data_loader(
        model_config["dataset_info"], args, benchmark_config, model_specs, num_replicas=world_size, rank=rank
    )
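    # Shift the token sequence by one position: the first seq_len positions are inputs,
    # and the same positions shifted by one are the next-token targets.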
    def get_batch(source):
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target
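    # Batch 0 is treated as warm-up: it runs but its tokens are excluded from the throughput count.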
    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            epoch_start_time = time.time()

        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += batch.numel()

        start_time = time.time()
        optimizer.zero_grad()
        source, target = get_batch(batch)
        source = source.to(device)
        target = target.to(device)
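        # Forward/backward with gradient value clipping; failures are re-raised tagged with the failing rank.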
        try:
            output = model(source)
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            total_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
            optimizer.step()
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        elapsed = time.time() - start_time
        total_elapsed += elapsed
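        # With log_interval = 1, per-batch loss, perplexity, and words-per-second are logged at debug level.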
        log_interval = 1
        total_tokens_per_log_interval += batch.numel()
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            logger.debug(
                "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                    i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                )
            )
            total_tokens_per_log_interval = 0
            total_loss = 0
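    # After the loop: report this rank's overall throughput and peak CUDA memory allocated.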
    wps = total_tokens / total_elapsed
    logger.debug("rank {}, wps: {}".format(rank, wps))
    logger.debug(
        "Peak allocated bytes on cuda:{}: {:1d}".format(
            dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
        )
    )


if __name__ == "__main__":
    args = utils.init_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
    logging.info(f"Running single process benchmark with args: {args}")
    benchmark_single_process(MOEConfig, args)
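# Example invocation (illustrative only; the actual flag spellings depend on utils.init_args(),
# which is assumed to expose the args.model_name, args.max_batch, and args.debug attributes used above):
#   python <this_script>.py --model_name lm --max_batch 10 --debug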