llm-studio / llm_studio /src /optimizers.py
qinfeng722's picture
Upload 322 files
5caedb4 verified
raw
history blame contribute delete
800 Bytes
from functools import partial
from typing import Any, List
import bitsandbytes as bnb
from torch import optim
class Optimizers:
    """Registry that maps optimizer names to the callables that build them.

    Each entry is either an optimizer class or a ``functools.partial`` that
    wraps one with some keyword arguments already bound.
    """

    # SGD and RMSprop are wrapped in ``partial`` so that their momentum
    # related keyword arguments are pre-bound at lookup time.
    # NOTE(review): the "AdamW8bit" key resolves to ``bnb.optim.Adam8bit``
    # rather than an AdamW variant -- confirm this is intentional.
    _optimizers = dict(
        Adam=optim.Adam,
        AdamW=optim.AdamW,
        SGD=partial(optim.SGD, momentum=0.9, nesterov=True),
        RMSprop=partial(optim.RMSprop, momentum=0.9, alpha=0.9),
        Adadelta=optim.Adadelta,
        AdamW8bit=bnb.optim.Adam8bit,
    )

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered optimizer names in sorted order."""
        registered = list(cls._optimizers)
        registered.sort()
        return registered

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up the optimizer builder registered under ``name``.

        Args:
            name: optimizer name

        Returns:
            The class (or partial) registered for ``name``, or ``None`` when
            no optimizer is registered under that name.
        """
        try:
            return cls._optimizers[name]
        except KeyError:
            # Unknown names yield None instead of raising.
            return None