import inspect

import torch
import torch.distributed as dist


def get_adamw(
    model: torch.nn.Module,
    weight_decay: float,
    learning_rate: float,
    betas: tuple[float, float],
    device_type: str,
) -> torch.optim.AdamW:
    """
    Create an AdamW optimizer for the given model with the specified hyperparameters.

    Args:
        model (torch.nn.Module): The model whose parameters will be optimized.
        weight_decay (float): The weight decay (decoupled, AdamW-style) applied to the decay group.
        learning_rate (float): The learning rate for the optimizer.
        betas (tuple[float, float]): Coefficients used for computing running averages of the gradient and its square.
        device_type (str): The device type ('cuda' or 'cpu') on which the optimizer will operate.

    Returns:
        torch.optim.AdamW: The AdamW optimizer configured with the specified parameters.
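
    Example (illustrative sketch; the tiny nn.Linear model and the
    hyperparameter values are placeholders, not the project's defaults):

        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        model = torch.nn.Linear(16, 4).to(device_type)
        optimizer = get_adamw(
            model,
            weight_decay=0.1,
            learning_rate=3e-4,
            betas=(0.9, 0.95),
            device_type=device_type,
        )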
    """
    # Collect the model's named parameters and keep only those that require gradients.
    param_dict = {pn: p for pn, p in model.named_parameters()}
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}

    # Weight decay is applied only to tensors with 2 or more dimensions
    # (weight matrices, embeddings); 1-D tensors such as biases and norm
    # parameters are left undecayed.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(
        f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
    )
    print(
        f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
    )

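    # The fused AdamW implementation is only exposed in newer PyTorch releases,
    # so detect it from the constructor signature and enable it only on CUDA.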
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and device_type == "cuda"
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(
        optim_groups, lr=learning_rate, betas=betas, **extra_args
    )
    print(f"using fused AdamW: {use_fused}")

    return optimizer


def get_optimizer(configs, model: torch.nn.Module) -> torch.optim.Optimizer:
    """Build the optimizer selected by configs.adam: AdamW with decay/no-decay
    parameter groups via get_adamw when use_adamw is set, otherwise plain Adam
    over all parameters."""
    if configs.adam.use_adamw:
        optimizer = get_adamw(
            model=model,
            weight_decay=configs.adam.weight_decay,
            learning_rate=configs.adam.lr,
            betas=(configs.adam.beta1, configs.adam.beta2),
            device_type="cuda" if torch.cuda.is_available() else "cpu",
        )
    else:
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=configs.adam.lr,
            weight_decay=configs.adam.weight_decay,
            betas=(configs.adam.beta1, configs.adam.beta2),
        )
    return optimizer
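

# Illustrative usage sketch for get_optimizer; the SimpleNamespace below only
# mimics the expected configs.adam fields, and the hyperparameter values are
# placeholders rather than the project's defaults:
#
#     from types import SimpleNamespace
#
#     configs = SimpleNamespace(
#         adam=SimpleNamespace(
#             use_adamw=True, lr=3e-4, weight_decay=0.1, beta1=0.9, beta2=0.95
#         )
#     )
#     optimizer = get_optimizer(configs, model)

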
def is_loss_nan_check(loss: torch.Tensor) -> bool:
    """Check whether the current loss is invalid (NaN or Inf).

    Args:
        loss: The loss tensor produced by the model.

    Returns:
        bool: True if the loss is NaN or Inf (on any rank when running
            distributed), False otherwise.
    """

    def is_nan(x):
        # True if any element of x is NaN or +/-Inf.
        return torch.isnan(x).any() or torch.isinf(x).any()

    def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM):
        # Sum across ranks when torch.distributed is initialized; otherwise
        # return the tensor unchanged (single-process case).
        if dist.is_initialized():
            dist.all_reduce(tensor, op=op)
        return tensor

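    # Encode the local NaN/Inf status as a scalar flag and sum it across ranks,
    # so that every process reaches the same decision for this step.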
    nan_flag = torch.tensor(
        1.0 if is_nan(loss) else 0.0,
        device=loss.device if torch.cuda.is_available() else None,
    )

    all_reduce_tensor(nan_flag)
    if nan_flag.item() > 0.0:
        return True
    return False
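

# Illustrative usage sketch inside a training loop; `loss` and `optimizer`
# are placeholders, not objects defined in this module:
#
#     if is_loss_nan_check(loss):
#         optimizer.zero_grad(set_to_none=True)  # all ranks skip this bad step together
#     else:
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad(set_to_none=True)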