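"""Benchmark FSDP training and evaluation of a transformer language model on WikiText-2.

`mp.spawn` launches one process per CUDA device. `train` measures throughput in
words per second (wps), and unless --dry_run is given the result is checked
against golden reference values for throughput and peak memory taken from
`benchmarks.golden_configs.lm_wikitext2`.
"""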
import argparse
from collections import defaultdict
from functools import reduce
import gc
import logging
import math
import operator
import time

from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from models import transformer_lm
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam

from benchmarks.golden_configs.lm_wikitext2 import FSDP as lm_wikitext2
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

RPC_PORT = 29501

def verify_peak_memory(rank, golden_config, std_dev):
    """Check that peak CUDA memory on device `rank` is below the golden value times `std_dev`."""
    logging.debug(
        "Peak allocated bytes on cuda:{:d}: {:d}".format(
            rank, torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
        )
    )
    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
    golden_ref = golden_config["peak_mem_usage"][rank]
    if not current_device_usage < golden_ref * std_dev:
        raise RuntimeError(
            "Peak memory usage for cuda device {:d} is {:d} which "
            "exceeds the golden reference value of {:d}.".format(rank, current_device_usage, golden_ref)
        )

def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""

    if torch.distributed.get_rank() == 0:
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
        if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
            raise RuntimeError(
                "Throughput(wps):{:.2f} is below the golden threshold of an "
                "average value of {:.2f} and standard dev of {:.2f}.".format(
                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
                )
            )

    for i in range(torch.cuda.device_count()):
        verify_peak_memory(i, golden_config, 1.1)

def init_random_seed(seed: int):
    """Seed the torch, CUDA, and numpy RNGs for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

def get_model_and_optimizer(args, device, benchmark_config, model_config):
    """Return an instantiated model and a factory that builds its optimizer."""

    if args.model_name == "lm":
        model = get_lm_model(args, device, model_config)
    else:
        raise RuntimeError("Unrecognized args.model_name: %s" % args.model_name)

    lr = benchmark_config["lr"]

    def make_adam(params):
        return Adam(params, lr=lr)

    optimizer = make_adam
    return model, optimizer

def get_lm_model(args, device, config):
    """Get language model (based on GPT-2) used for sequence prediction."""

    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]

    return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

def get_tensors_by_size_bucket():
    """Bucket all live CUDA tensors by (shape, element_size) for memory debugging."""

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1

    return size_buckets

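# Note: get_tensors_by_size_bucket() is not called anywhere in the benchmark
# flow; it is a debugging aid that can be invoked (e.g. from a debugger) to
# inspect which CUDA tensor shapes dominate memory usage.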
def log_number_of_parameters(model):
    """Log the parameter count, all-reducing across the process group if the model has one."""

    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
    if hasattr(model, "group"):
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        print(
            f"training model, #params = {num_params/10**6}M, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            print(f"total #params = {total.item()}")
    else:
        print(f"training model, #params = {num_params/10**6}M")

def get_device(model, index):
    """Return the device at `index` for `model`, falling back to the current CUDA device."""
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if hasattr(model, "devices"):
        return model.devices[index]
    else:
        # Wrap the device index so all branches return a torch.device.
        return torch.device("cuda", torch.cuda.current_device())

def get_fake_dataloader(lm_dataloader_len, args):
    """Return a fixed-length dataset that yields a constant fake batch."""
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()

def train(model_config, model, benchmark_config, model_specs, args):
    """Run one epoch of training (or forward-only evaluation) and return (wps, last loss)."""
    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    if not args.benchmark_eval:
        model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    optimizer = optimizer(model.parameters())

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
        # Shift the sequence by one token: predict source[t+1] from source[t].
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target
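    # Throughput accounting: the first batch is treated as warmup. The epoch
    # timer starts at batch 1 and token counts skip batch 0, so the reported
    # wps reflects steady-state speed rather than one-time startup costs.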
    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        if args.full_fp16:
            target = target.half()
        if args.max_batch and i > args.max_batch:
            break

        if i > 0:
            total_tokens += source.numel()

        if args.benchmark_eval:
            input = source.cuda()
            target = target.cuda()
            output = model(input)
            print(f"output.dtype {output.dtype}, target.dtype {target.dtype}")
            loss = torch.nn.CrossEntropyLoss()(output.view(-1, vocab_size), target.view(-1))
        else:
            optimizer.zero_grad()
            input = source.cuda()
            target = target.cuda()
            output = model(input)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            loss.backward()

            torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
            optimizer.step()

        total_loss += loss.item()

        log_interval = 1
        total_tokens_per_log_interval += source.numel()
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            if dist.get_rank() == 0:
                print(
                    "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                        i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                    )
                )
            total_tokens_per_log_interval = 0
            total_loss = 0
            start_time = time.time()

    if epoch_start_time != 0:
        torch.cuda.synchronize()
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size "
            "of the dataset and rerun the benchmark."
        )
    return wps, loss.item()

def get_number_of_words(data):
    """Return the number of tokens in a (seq_len, batch) tensor."""
    return data.size()[0] * data.size()[1]

def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
    """Run one timed epoch and verify throughput and peak memory against golden data."""
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
    if dist.get_rank() == 0:
        print("-" * 110)
        print("| start of epoch {:1d}".format(epoch))
        print("-" * 110)
    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
    elapsed_time = time.time() - start_time
    if dist.get_rank() == 0:
        print("-" * 110)
        print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
        print("-" * 110)
        print("Throughput(wps) is {:.2f}.".format(wps))
        print(
            "Peak allocated bytes on cuda:{}: {:.4f}GB".format(
                dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"] / 2**30
            )
        )

    verify_lm_run(wps, golden_config, args)

def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
    """Return dataloaders for synthetic data."""

    if args.model_name == "lm":
        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
    else:
        raise RuntimeError("Unrecognized args.model_name: %s" % args.model_name)

def get_real_dataloaders(args, device, benchmark_config, model_specs):
    """Return dataloaders for real data, updating the vocab size in `model_specs`."""

    if args.model_name == "lm":
        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
        model_specs["vocab_size"] = ntokens
        return train_dataloader, valid_dataloader, test_dataloader
    else:
        raise RuntimeError("Unrecognized args.model_name: %s" % args.model_name)

def create_model_config(args, benchmark_config=None, model_specs=None):
    """Return a dict with the model, dataloaders, and optimizer factory."""

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if args.use_synthetic_data:
        dataloader_fn = get_synthetic_dataloaders
    else:
        dataloader_fn = get_real_dataloaders

    data = dataloader_fn(args, device, benchmark_config, model_specs)
    model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
    return {
        "model": model,
        "optimizer": optimizer,
        "data": data,
    }

def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking the `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    else:
        raise RuntimeError("Unrecognized model_name: %s" % model_name)

def get_model_specs(model_name):
    """Return a dict with configurations required for configuring the `model_name` model."""

    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    else:
        raise RuntimeError("Unrecognized model_name: %s" % model_name)

def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage."""

    if model_name == "lm":
        return lm_wikitext2.get_golden_synthetic_stats()
    else:
        raise RuntimeError("Unrecognized model_name: %s" % model_name)

def benchmark_fsdp(rank, args, world_size):
    """Benchmark a given model using a single process per device."""

    init_method_pgroup = "tcp://localhost:{}".format(RPC_PORT)
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )

    torch.cuda.set_device(rank)
    init_random_seed(0)
    # Spawned workers do not inherit the parent's logging config, so set it here.
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    benchmark_config = create_benchmark_config(args.model_name)
    model_specs = get_model_specs(args.model_name)
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]
    config = {}

    if args.full_fp16:
        config["compute_dtype"] = torch.float16
        config["mixed_precision"] = False

    if args.enable_auto_wrap:
        # Wrap eligible submodules in FSDP first, then wrap the root module.
        with enable_wrap(wrapper_cls=FSDP, **config):
            fsdp_model = auto_wrap(model, auto_wrap_policy=default_auto_wrap_policy)
            fsdp_model = FSDP(fsdp_model, **config)
    else:
        fsdp_model = FSDP(model, **config)

    if args.full_fp16:
        fsdp_model = fsdp_model.half()
        print(f"param dtype {[p.dtype for p in fsdp_model.parameters()]}")
    if args.dry_run:
        train(model_config, fsdp_model, benchmark_config, model_specs, args)
    else:
        benchmark_language_model(model_config, fsdp_model, benchmark_config, model_specs, args)

parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--max_batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    "--model_name",
    default="lm",
    help="Language model (LM) used to benchmark FSDP.",
)
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
parser.add_argument("--enable_auto_wrap", action="store_true", default=False, help="Use auto_wrap with FSDP")
parser.add_argument("--benchmark_eval", action="store_true", default=False, help="Benchmark evaluation workflow.")
parser.add_argument("--full_fp16", action="store_true", default=False, help="Benchmark in full fp16 mode.")

if __name__ == "__main__":
    args = parser.parse_args()
    # Honor the --debug flag instead of always logging at DEBUG level.
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    print(f"Running FSDP benchmark with args: {args}")
    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0

    mp.spawn(
        benchmark_fsdp,
        args=(args, num_devices),
        nprocs=num_devices,
        join=True,
    )
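# Example invocations (assumed paths; adjust to where this script lives in your
# checkout). Runs without --use_synthetic_data need the WikiText-2 dataset
# available to the real dataloaders.
#   python benchmarks/fsdp.py --use_synthetic_data --max_batch 10
#   python benchmarks/fsdp.py --use_synthetic_data --enable_auto_wrap --full_fp16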