Spaces:
Running
Running
File size: 16,598 Bytes
80ebcb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 |
import functools
import math
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_optimizer_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from .parallel import ParallelBackendEnum
from .utils.import_utils import is_bitsandbytes_available
class OptimizerWrapper(Stateful):
r"""
Optimizer wrapper that:
- allows step/zero_grad on multiple optimizers needed for virtual pipeline stages
- saves/loading optimizer state_dict at checkpoint
"""
def __init__(
self,
model_parts: List[torch.nn.Module],
optimizer_cls: Type[torch.optim.Optimizer],
optimizer_kwargs: Dict[str, Any],
) -> None:
self.optimizer_cls = optimizer_cls
self.optimizer_kwargs = optimizer_kwargs
self.optimizers = []
self.model_parts = model_parts
for model in self.model_parts:
optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
self.optimizers.append(optimizer)
def step(self) -> None:
for optimizer in self.optimizers:
optimizer.step()
def zero_grad(self) -> None:
for optimizer in self.optimizers:
optimizer.zero_grad()
def state_dict(self) -> Dict[str, Any]:
func = functools.partial(
get_optimizer_state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
func = functools.partial(
set_optimizer_state_dict,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
list(map(func, self.model_parts, self.optimizers))
class SchedulerWrapper:
    r"""
    Scheduler wrapper that creates one `LambdaLR` scheduler per optimizer and drives them together.

    Args:
        optimizers:
            Iterable of optimizers. One `torch.optim.lr_scheduler.LambdaLR` is created for each.
        scheduler_lambda_fn (`Callable[[int], float]`):
            Function mapping the current step to an lr multiplier. It is passed to every `LambdaLR` as its
            `lr_lambda`.
        last_epoch (`int`):
            Forwarded to every `LambdaLR` as `last_epoch`.
    """

    def __init__(self, optimizers, scheduler_lambda_fn: Callable[[int], float], last_epoch: int) -> None:
        self.schedulers = [
            torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
            for optimizer in optimizers
        ]

    def step(self) -> None:
        """Call `step` on every wrapped scheduler, in order."""
        for scheduler in self.schedulers:
            scheduler.step()

    def get_last_lr(self) -> Dict[str, List[float]]:
        """Return each scheduler's `get_last_lr()` result keyed by `"lr_<index>"`."""
        # TODO(aryan): look into this later. Currently calling it leads to NCCL hang?????
        return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}

    def get_lr_scheduler_state(self) -> Dict[str, Any]:
        """
        Return the scheduler objects keyed by name.

        A single scheduler is stored under `"lr_scheduler"`; otherwise each scheduler is stored under
        `"lr_scheduler_<index>"`.
        """
        if len(self.schedulers) == 1:
            return {"lr_scheduler": self.schedulers[0]}
        # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
        # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
        return {f"lr_scheduler_{idx}": lr_scheduler for idx, lr_scheduler in enumerate(self.schedulers)}
def get_optimizer(
    parallel_backend: ParallelBackendEnum,
    name: str,
    model_parts: List[torch.nn.Module],
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.999,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    fused: bool = False,
) -> Union[torch.optim.Optimizer, OptimizerWrapper]:
    r"""
    Build the optimizer called `name` for `model_parts` on the given parallel backend.

    Args:
        parallel_backend (`ParallelBackendEnum`):
            Backend that decides how the optimizer is built: `ACCELERATE` goes through
            `get_optimizer_accelerate`, `PTD` goes through `get_optimizer_ptd`.
        name (`str`):
            Case-insensitive optimizer name. One of `"adam"`, `"adamw"`, `"adam-bnb"`, `"adamw-bnb"`,
            `"adam-bnb-8bit"` or `"adamw-bnb-8bit"`.
        model_parts (`List[torch.nn.Module]`):
            The modules whose parameters are optimized.
        learning_rate (`float`, *optional*, defaults to 1e-3):
            Passed to the optimizer as `lr`.
        beta1 (`float`, *optional*, defaults to 0.9):
            First element of the optimizer's `betas`.
        beta2 (`float`, *optional*, defaults to 0.95):
            Second element of the optimizer's `betas`.
        beta3 (`float`, *optional*, defaults to 0.999):
            Accepted but not used by any of the currently supported optimizers.
        epsilon (`float`, *optional*, defaults to 1e-8):
            Passed to the optimizer as `eps`.
        weight_decay (`float`, *optional*, defaults to 1e-4):
            Passed to the optimizer as `weight_decay`.
        fused (`bool`, *optional*, defaults to `False`):
            Passed as `fused` to the torch optimizers only (`"adam"` and `"adamw"`).

    Raises:
        ImportError: If a `-bnb` optimizer is requested and bitsandbytes is not available.
        ValueError: If `name` or `parallel_backend` is not supported.
    """
    name = name.lower()
    _raise_errors_if_packages_not_available(name)

    # Hyperparameters shared by every supported optimizer.
    optimizer_kwargs = {
        "lr": learning_rate,
        "betas": (beta1, beta2),
        "eps": epsilon,
        "weight_decay": weight_decay,
    }
    torch_optimizers = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}
    # Optimizer name -> class name inside `bitsandbytes.optim`.
    bnb_optimizers = {
        "adam-bnb": "Adam",
        "adamw-bnb": "AdamW",
        "adam-bnb-8bit": "Adam8bit",
        "adamw-bnb-8bit": "AdamW8bit",
    }

    if name in torch_optimizers:
        optimizer_cls = torch_optimizers[name]
        # Only the torch implementations are given the `fused` flag.
        optimizer_kwargs["fused"] = fused
    elif name in bnb_optimizers:
        # Imported lazily so that bitsandbytes stays an optional dependency.
        import bitsandbytes.optim as bnb_optim

        optimizer_cls = getattr(bnb_optim, bnb_optimizers[name])
    # TODO(aryan): handle torchao
    else:
        raise ValueError(f"Unsupported optimizer: {name}")

    if parallel_backend == ParallelBackendEnum.ACCELERATE:
        return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs)
    if parallel_backend == ParallelBackendEnum.PTD:
        return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs)
    # Fail loudly instead of implicitly returning None for an unknown backend.
    raise ValueError(f"Unsupported parallel backend: {parallel_backend}")
def get_optimizer_accelerate(
    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
) -> torch.optim.Optimizer:
    r"""
    Create a single optimizer over the trainable parameters of all model parts.

    Args:
        model_parts (`List[torch.nn.Module]`):
            The modules whose parameters are collected. Parameters with `requires_grad=False` are skipped.
        optimizer_cls (`Type[torch.optim.Optimizer]`):
            The optimizer class to instantiate.
        optimizer_kwargs (`Dict[str, Any]`):
            Keyword arguments forwarded to `optimizer_cls`.
    """
    trainable_params = []
    for part in model_parts:
        trainable_params.extend(p for p in part.parameters() if p.requires_grad)
    return optimizer_cls(trainable_params, **optimizer_kwargs)
def get_optimizer_ptd(
    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
) -> OptimizerWrapper:
    r"""
    Create an `OptimizerWrapper` for `model_parts`.

    Args:
        model_parts (`List[torch.nn.Module]`):
            The modules handed to the wrapper.
        optimizer_cls (`Type[torch.optim.Optimizer]`):
            The optimizer class handed to the wrapper.
        optimizer_kwargs (`Dict[str, Any]`):
            The optimizer keyword arguments handed to the wrapper.
    """
    wrapper = OptimizerWrapper(
        model_parts=model_parts,
        optimizer_cls=optimizer_cls,
        optimizer_kwargs=optimizer_kwargs,
    )
    return wrapper
def get_lr_scheduler(
    parallel_backend: ParallelBackendEnum,
    name: str,
    optimizer: Union[torch.optim.Optimizer, OptimizerWrapper],
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    lr_init: float = 1e-3,
    lr_end: float = 1e-7,
    last_epoch: int = -1,
) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]:
    r"""
    Build the learning rate scheduler called `name` for `optimizer` on the given parallel backend.

    Args:
        parallel_backend (`ParallelBackendEnum`):
            Backend that decides how the scheduler is built: `ACCELERATE` goes through
            `get_lr_scheduler_accelerate`, `PTD` goes through `get_lr_scheduler_ptd`.
        name (`str`):
            Case-insensitive schedule name. One of `"constant"`, `"constant_with_warmup"`,
            `"piecewise_constant"`, `"linear"`, `"cosine"`, `"cosine_with_restarts"` or `"polynomial"`.
        optimizer (`torch.optim.Optimizer` or `OptimizerWrapper`):
            The optimizer (or optimizer wrapper) to schedule.
        step_rules (`str`, *optional*):
            Rule string used by the `"piecewise_constant"` schedule.
        num_warmup_steps (`int`, *optional*):
            Warmup length, used by every schedule except `"constant"` and `"piecewise_constant"`.
        num_training_steps (`int`, *optional*):
            Total number of training steps, used by `"linear"`, `"cosine"`, `"cosine_with_restarts"` and
            `"polynomial"`.
        num_cycles (`int`, *optional*, defaults to 1):
            Number of cycles, used by `"cosine"` and `"cosine_with_restarts"`.
        power (`float`, *optional*, defaults to 1.0):
            Power factor of the `"polynomial"` schedule.
        lr_init (`float`, *optional*, defaults to 1e-3):
            Initial lr assumed by the `"polynomial"` schedule.
        lr_end (`float`, *optional*, defaults to 1e-7):
            End lr of the `"polynomial"` schedule.
        last_epoch (`int`, *optional*, defaults to -1):
            Forwarded to the created scheduler(s).

    Raises:
        ValueError: If `name` or `parallel_backend` is not supported.
    """
    name = name.lower()
    if name == "constant":
        scheduler_lambda_fn = get_constant_schedule()
    elif name == "constant_with_warmup":
        scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps)
    elif name == "piecewise_constant":
        scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
    elif name == "linear":
        scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
    elif name == "cosine":
        # NOTE(review): `num_cycles` defaults to 1 here, so the default cosine schedule completes a full period
        # (the multiplier is back at 1.0 when training ends) rather than the half-cosine decay that
        # `get_cosine_schedule_with_warmup` uses by default (0.5) - confirm this is intended.
        scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
    elif name == "cosine_with_restarts":
        scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
            num_warmup_steps, num_training_steps, num_cycles
        )
    elif name == "polynomial":
        scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
            num_warmup_steps, num_training_steps, lr_init, lr_end, power
        )
    else:
        raise ValueError(f"Unsupported scheduler: {name}")

    if parallel_backend == ParallelBackendEnum.ACCELERATE:
        return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
    if parallel_backend == ParallelBackendEnum.PTD:
        return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)
    # Fail loudly instead of implicitly returning None for an unknown backend.
    raise ValueError(f"Unsupported parallel backend: {parallel_backend}")
def get_lr_scheduler_accelerate(
    optimizer: torch.optim.Optimizer,
    scheduler_lambda_fn: Callable[[int], float],
    last_epoch: int = -1,
) -> torch.optim.lr_scheduler.LambdaLR:
    r"""
    Create a `LambdaLR` scheduler for a single optimizer.

    Args:
        optimizer (`torch.optim.Optimizer`):
            The optimizer to schedule.
        scheduler_lambda_fn (`Callable[[int], float]`):
            Function mapping the current step to an lr multiplier. It is passed to `LambdaLR` as `lr_lambda`.
        last_epoch (`int`, *optional*, defaults to -1):
            Forwarded to `LambdaLR` as `last_epoch`.
    """
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_lambda_fn, last_epoch=last_epoch)
def get_lr_scheduler_ptd(
    optimizer: OptimizerWrapper, scheduler_lambda_fn: Callable[[int], float], last_epoch: int = -1
) -> SchedulerWrapper:
    r"""
    Create a `SchedulerWrapper` covering every optimizer held by an `OptimizerWrapper`.

    Args:
        optimizer (`OptimizerWrapper`):
            The wrapper whose `optimizers` are scheduled.
        scheduler_lambda_fn (`Callable[[int], float]`):
            Function mapping the current step to an lr multiplier.
        last_epoch (`int`, *optional*, defaults to -1):
            Forwarded to the `SchedulerWrapper`.
    """
    return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)
# ==============================
# Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py
# ==============================
def get_constant_schedule() -> Callable[[int], float]:
    r"""
    Create a schedule that keeps the learning rate constant at the value set in the optimizer.

    Returns:
        A function of the current step that always returns the multiplier 1.0.
    """

    def constant_multiplier(current_step: int):
        # The step is ignored: the optimizer's lr is never scaled.
        return 1.0

    return constant_multiplier
def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
    r"""
    Create a schedule whose lr multiplier rises linearly from 0 to 1 during warmup and then stays at 1, i.e. the
    learning rate ends up constant at the initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
    """

    def warmup_then_constant(current_step: int):
        in_warmup = current_step < num_warmup_steps
        if not in_warmup:
            return 1.0
        # Linear ramp; the denominator is clamped to at least 1.
        warmup_span = float(max(1.0, num_warmup_steps))
        return float(current_step) / warmup_span

    return warmup_then_constant
def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]:
    r"""
    Create a schedule whose lr multiplier is piecewise constant over step ranges.

    Args:
        step_rules (`str`):
            Comma-separated rules of the form `multiplier:step_boundary`, followed by a final bare multiplier.
            Boundaries are absolute step counts, not durations. For example,
            `step_rules="1:10,0.1:20,0.01:30,0.005"` multiplies the learning rate by 1 while the step is below 10,
            by 0.1 while it is below 20, by 0.01 while it is below 30, and by 0.005 for all later steps. If a
            boundary is repeated, the last rule given for it is used.

    Returns:
        A function mapping the current step to the lr multiplier.
    """
    *bounded_rules, final_rule = step_rules.split(",")

    multiplier_by_boundary = {}
    for rule_str in bounded_rules:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        value = float(value_str)
        multiplier_by_boundary[steps] = value
    last_lr_multiple = float(final_rule)

    # Sort once up front so each lookup only scans the boundaries in ascending order.
    sorted_rules = sorted(multiplier_by_boundary.items())

    def rule_func(steps: int) -> float:
        for boundary, multiplier in sorted_rules:
            if steps < boundary:
                return multiplier
        return last_lr_multiple

    return rule_func
def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]:
    r"""
    Create a schedule whose lr multiplier rises linearly from 0 to 1 during warmup and then falls linearly to 0 at
    `num_training_steps`, i.e. the learning rate goes from 0 up to the initial lr set in the optimizer and back
    down to 0.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
    """

    def linear_multiplier(current_step: int):
        if current_step < num_warmup_steps:
            warmup_span = float(max(1, num_warmup_steps))
            return float(current_step) / warmup_span
        steps_left = float(num_training_steps - current_step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        # Clamp at 0 once the step runs past the end of training.
        return max(0.0, steps_left / decay_span)

    return linear_multiplier
def get_cosine_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
) -> Callable[[int], float]:
    r"""
    Create a schedule whose lr multiplier rises linearly from 0 to 1 during warmup and then follows a cosine curve
    over the remaining training steps, starting from the initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the default is to just decrease from the
            max value to 0 following a half-cosine).
    """

    def cosine_multiplier(current_step):
        if current_step < num_warmup_steps:
            warmup_span = float(max(1, num_warmup_steps))
            return float(current_step) / warmup_span
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / decay_span
        angle = math.pi * float(num_cycles) * 2.0 * progress
        cosine_value = 0.5 * (1.0 + math.cos(angle))
        return max(0.0, cosine_value)

    return cosine_multiplier
def get_cosine_with_hard_restarts_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 1,
) -> Callable[[int], float]:
    r"""
    Create a schedule whose lr multiplier rises linearly from 0 to 1 during warmup and then repeatedly decays from
    1 to 0 along a half-cosine, jumping back to 1 at every hard restart. The multiplier is 0 once training is
    complete.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
    """

    def restart_multiplier(current_step):
        if current_step < num_warmup_steps:
            warmup_span = float(max(1, num_warmup_steps))
            return float(current_step) / warmup_span
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / decay_span
        if progress >= 1.0:
            return 0.0
        # Position inside the current cycle, in [0, 1); wrapping to 0 is the hard restart.
        cycle_position = (float(num_cycles) * progress) % 1.0
        cosine_value = 0.5 * (1.0 + math.cos(math.pi * cycle_position))
        return max(0.0, cosine_value)

    return restart_multiplier
def get_polynomial_decay_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    lr_init: float,
    lr_end: float = 1e-7,
    power: float = 1.0,
) -> Callable[[int], float]:
    r"""
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_init (`float`):
            The initial LR. The returned multipliers are expressed relative to this value, so it should match the
            lr set in the optimizer.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Raises:
        ValueError: If `lr_init` is not strictly greater than `lr_end`.
    """
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    lr_range = lr_init - lr_end
    # Clamp to at least 1 (as the other warmup schedules do) so that
    # `num_training_steps == num_warmup_steps` does not divide by zero.
    decay_steps = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        if current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decay = lr_range * pct_remaining**power + lr_end
        return decay / lr_init  # as LambdaLR multiplies by lr_init

    return lr_lambda
def _raise_errors_if_packages_not_available(name: str) -> None:
    r"""
    Check that the optional package named in an optimizer name is available.

    The package is the second `-`-separated segment of `name` (e.g. `bnb` in `adamw-bnb-8bit`). Names without a
    `-` and names whose package segment is not `bnb` are not checked.

    Args:
        name (`str`):
            The optimizer name.

    Raises:
        ImportError: If the package segment is `bnb` and bitsandbytes is not available.
    """
    segments = name.split("-")
    needs_bnb = len(segments) >= 2 and segments[1] == "bnb"
    if needs_bnb and not is_bitsandbytes_available():
        raise ImportError(
            f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer."
        )
|