|
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple

import torch
from allennlp.training.optimizers import Registrable, Optimizer, make_parameter_groups
from pytorch_pretrained_bert.optimization import BertAdam
|
|
|
|
|
@Optimizer.register("bert_adam")
class BertAdamOptimizer(Optimizer, BertAdam):
    """``BertAdam`` registered with the ``Optimizer`` registry as ``"bert_adam"``.

    The class adds no optimisation logic of its own. It only adapts the
    constructor: named model parameters (plus optional per-group overrides)
    are turned into parameter groups with ``make_parameter_groups`` and the
    remaining hyper-parameters are forwarded unchanged to the parent
    initialiser.
    """

    def __init__(
        self,
        model_parameters: List[Tuple[str, torch.nn.Parameter]],
        parameter_groups: Optional[List[Tuple[List[str], Dict[str, Any]]]] = None,
        lr: float = 5e-5,
        warmup: float = 0.1,
        t_total: int = 50000,
        schedule: str = 'warmup_linear',
    ) -> None:
        """Build the parameter groups and initialise the underlying optimizer.

        Args:
            model_parameters: ``(name, parameter)`` pairs for the model being
                optimised.
            parameter_groups: Optional group specification handed to
                ``make_parameter_groups`` together with ``model_parameters``.
                ``None`` (the default) means no grouping was requested.
            lr: Learning rate, forwarded as ``lr``.
            warmup: Forwarded as ``warmup``. See the BertAdam documentation
                for its exact meaning (presumably the warmup share of
                ``t_total`` -- confirm against the installed version).
            t_total: Forwarded as ``t_total``. See the BertAdam documentation
                (presumably the total number of scheduled training steps --
                confirm).
            schedule: Name of the schedule, forwarded as ``schedule``.
        """
        # NOTE(review): with cooperative inheritance this call is expected to
        # reach BertAdam.__init__ through the MRO -- confirm against the
        # installed allennlp / pytorch_pretrained_bert versions.
        super().__init__(
            params=make_parameter_groups(model_parameters, parameter_groups),
            lr=lr,
            warmup=warmup,
            t_total=t_total,
            schedule=schedule,
        )
|
|
|
|
|
if __name__ == '__main__':
    # Nothing to run: this module only defines and registers the optimizer,
    # so executing it directly is intentionally a no-op.
    pass
|
|