Spaces:
Sleeping
Sleeping
File size: 9,412 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
import math
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
from transformers import GPT2Config
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
num_heads = dim // head_dim
assert num_heads % world_size == 0
vocab_size = 50264
assert vocab_size % world_size == 0
num_layers = 2
rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
input_ids = torch.randint(0, vocab_size, (batch_size, seqlen + 1), device=device)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
g = torch.randn(batch_size * seqlen, device=device)
config = GPT2Config(
n_embd=dim,
n_head=num_heads,
n_layer=num_layers,
n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True,
use_flash_attn=True,
fused_mlp=True,
fused_bias_fc=True,
fused_dropout_add_ln=True,
residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel,
)
config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
model_pt = GPTLMHeadModel(config, device=device)
def init_layer_norm(module):
if isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight)
nn.init.normal_(module.bias)
model_pt.apply(init_layer_norm)
model = GPTLMHeadModel(config, process_group=process_group, device=device)
total_nparams = sum(p.numel() for p in model_pt.parameters())
sharded_nparams = sum(p.numel() for p in model.parameters())
sharded_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
)
shared_nparams = sum(
p.numel() for p in model.parameters() if getattr(p, "_shared_params", False)
)
shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
)
assert torch.all(shared_nparams_all == shared_nparams)
assert total_nparams == (
(sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams
)
# vocab_size has been rounded up here
partition_vocab_size = config.vocab_size // world_size
partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size
with torch.no_grad():
model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
model.tie_weights()
with torch.autocast(device_type="cuda", dtype=dtype):
out = model(input_ids[:, :-1]).logits
if not sequence_parallel:
out = rearrange(out, "b s d -> (b s) d")
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
rtol=rtol,
atol=atol,
)
loss_fn = CrossEntropyLoss(inplace_backward=True, reduction="none", process_group=process_group)
loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none")
loss = loss_fn(out, input_ids[:, 1:].flatten())
loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)
loss_pt.backward(g)
loss.backward(g)
allreduce_sequence_parallel_grad(model, process_group)
parallel_state.destroy_model_parallel()
grad_dict = shard_state_dict_tp(
{k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank
)
assert torch.allclose(
model.transformer.embeddings.word_embeddings.weight.grad,
grad_dict["transformer.embeddings.word_embeddings.weight"],
rtol=rtol,
atol=atol * 5,
)
if has_pos_emb:
assert torch.allclose(
model.transformer.embeddings.position_embeddings.weight.grad,
grad_dict["transformer.embeddings.position_embeddings.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.ln_f.weight.grad,
grad_dict["transformer.ln_f.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol
)
for i in range(num_layers):
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.weight.grad,
grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.bias.grad,
grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.weight.grad,
grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.bias.grad,
grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.weight.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.bias.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.weight.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.bias.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.transformer.layers[i].norm1.weight.grad,
grad_dict[f"transformer.layers.{i}.norm1.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm1.bias.grad,
grad_dict[f"transformer.layers.{i}.norm1.bias"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.weight.grad,
grad_dict[f"transformer.layers.{i}.norm2.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.bias.grad,
grad_dict[f"transformer.layers.{i}.norm2.bias"],
rtol=rtol,
atol=atol,
)
|