from collections import defaultdict
import gc
import logging
import math
import time

import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.nn.parallel import DistributedDataParallel as DDP
import utils

from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
from fairscale.fair_dev.testing.testing import dist_init
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel

MPI_PORT = 29500
RPC_PORT = 29501


def get_tensors_by_size_bucket():
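    """Count live CUDA tensors found by the garbage collector, bucketed by (shape, element_size)."""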
    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    return size_buckets


def get_device(model, index):
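    """Return the device for partition `index` of `model`, unwrapping DDP if needed."""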
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(model, "devices"):
        return model.devices[index]
    else:
        return torch.cuda.current_device()


def get_fake_dataloader(lm_dataloader_len, args):
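    """Return a fake dataset of constant zero batches with the same length as the real LM dataloader."""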
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()


def train(model_config, model, benchmark_config, model_specs, args):
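    """Train the pipelined model for one epoch and return (words per second, last loss) on the final rank."""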
    lm_dataloader, _, _ = utils.get_data_loader(model_config["dataset_info"], args, benchmark_config, model_specs)
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    utils.log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    optimizer = optimizer(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config, model_specs)

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
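        # Inputs are all tokens except the last; targets are the same sequence shifted by one token.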
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
            if pipe_group is None or pipe_group.rank() == 0:
                tmp = source.to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = target.to(get_device(model, -1))
            output = output.to(target.device)
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            loss.backward()
            del target
        else:
            model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            total_tokens_per_log_interval += source.numel()
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                if dist.get_rank() == dist.get_world_size() - 1:
                    logging.debug(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0
                start_time = time.time()

    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
        )
    if dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0


def evaluate(eval_model, data_source, criterion, ntokens):
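    """Evaluate `eval_model` on `data_source` and return the average per-token loss."""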
    eval_model.eval()
    total_loss = 0.0

    bptt = 35

    def get_batch(source, i, bptt):
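        # Slice a window of at most `bptt` tokens starting at `i`; the target is that window shifted by one, flattened.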
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i : i + seq_len]
        target = source[i + 1 : i + 1 + seq_len].view(-1)
        return data, target

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
            output = eval_model(data)
            output = output.to(targets.device)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


def get_number_of_words(data):
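    """Return the number of tokens in a 2D batch tensor (sequence length x batch size)."""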
    return data.size()[0] * data.size()[1]


def verify_peak_memory(rank, golden_config, std_dev):
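    """Check that peak allocated CUDA memory on `rank` stays below the golden reference scaled by `std_dev`."""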
    logging.debug(
        "Peak allocated bytes on cuda:{:d}: {:1d}".format(rank, torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"])
    )
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "exceeds the golden reference value of {:d}.".format(rank, current_device_usage, golden_ref)
        )


def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    if dist.get_rank() == dist.get_world_size() - 1:
        # Throughput must be within three standard deviations of the golden average.
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
        if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

    for i in range(4):
        verify_peak_memory(i, golden_config, 1.1)


def benchmark_language_model(model_config, model, benchmark_config, model_specs, config_class, args):
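    """Run one training epoch, log timing and peak-memory statistics, and verify the run against the golden config."""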
    golden_config = get_golden_config(args.model_name, config_class, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| start of epoch {:1d}".format(epoch))
        logging.debug("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == dist.get_world_size() - 1:
        logging.debug("-" * 110)
        logging.debug("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        logging.debug("-" * 110)
        logging.debug("Throughput(wps) is {:.2f}.".format(wps))
        logging.debug(
            "Peak allocated bytes on cuda:{}: {:1d}".format(
                dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
            )
        )

    if len(model.balance) == 4:
        if args.model_name == "lm":
            verify_lm_run(wps, golden_config, args)
        else:
            raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def generate_balance(num_devices, num_layers):
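    """Split `num_layers` across `num_devices` as evenly as possible, front-loading any remainder.

    For example, generate_balance(4, 14) returns [4, 4, 3, 3].
    """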
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = (num_layers - layers_assigned) / (num_devices - i)
        if x.is_integer():
            balance.append(int(x))
            layers_assigned += x
        else:
            balance.append(math.ceil(x))
            layers_assigned += math.ceil(x)
    return balance


def get_golden_config(model_name, config_class, args):
    """Return a dict with the golden data for throughput and memory usage."""

    if model_name == "lm":
        return config_class.get_golden_real_stats()
    else:
        raise RuntimeError("Unrecognized args.model_name %s" % args.model_name)


def benchmark_single_process(config_class, args):
"""Benchmark a given model using a single process and multiple devices.""" |
|
|
|
init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT) |
|
torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1, init_method=init_method_pgroup) |
|
|
|
num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 |
|
assert num_devices > 0 |
|
utils.init_random_seed(0) |
|
|
|
benchmark_config = utils.create_benchmark_config(args.model_name, config_class) |
|
model_specs = utils.get_model_specs(args.model_name, config_class) |
|
model_config = utils.create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs) |
|
model = model_config["model"] |
|
|
|
balance = generate_balance(min(num_devices, 4), len(model)) |
|
pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint) |
|
del model |
|
del model_config["model"] |
|
|
|
if args.dry_run: |
|
train(model_config, pipe_model, benchmark_config, model_specs, args) |
|
else: |
|
benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, config_class, args) |
|
|
|
|
|
def run_worker(rank, world_size, args): |
|
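    """Initialize distributed state for this rank, run the worker, then shut down RPC and the process group."""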
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    utils.init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    args = utils.init_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    logging.info(f"Running single process benchmark with args: {args}")
    benchmark_single_process(lm_wikitext2, args)