from typing import TYPE_CHECKING, Optional, Union
from easydict import EasyDict
import os
import numpy as np

from ding.utils import save_file
from ding.policy import Policy
from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext

class CkptSaver:
    """
    Overview:
        The class used to save checkpoint data. A usage sketch is given at the end of this file.
    """

    def __new__(cls, *args, **kwargs):
        # In distributed (task.router) mode, only learner and evaluator processes need to
        # write checkpoints; every other process receives a no-op placeholder instead.
        if task.router.is_active and not (task.has_role(task.role.LEARNER) or task.has_role(task.role.EVALUATOR)):
            return task.void()
        return super(CkptSaver, cls).__new__(cls)
    def __init__(self, policy: Policy, save_dir: str, train_freq: Optional[int] = None, save_finish: bool = True):
        """
        Overview:
            Initialize the `CkptSaver`.
        Arguments:
            - policy (:obj:`Policy`): Policy whose state dict is saved in each checkpoint.
            - save_dir (:obj:`str`): The directory path to save checkpoints. A ``ckpt`` sub-directory is \
                appended unless the path already ends in ``ckpt`` (see the path examples after this method).
            - train_freq (:obj:`int`): Number of training iterations between two consecutive checkpoint saves. \
                If ``None``, no iteration-based checkpoints are saved.
            - save_finish (:obj:`bool`): Whether to save a final checkpoint when ``task.finish`` is True.
        """
        self.policy = policy
        self.train_freq = train_freq
        if str(os.path.basename(os.path.normpath(save_dir))) != "ckpt":
            self.prefix = '{}/ckpt'.format(os.path.normpath(save_dir))
        else:
            self.prefix = '{}/'.format(os.path.normpath(save_dir))
        if not os.path.exists(self.prefix):
            os.makedirs(self.prefix)
        self.last_save_iter = 0
        self.max_eval_value = -np.inf
        self.save_finish = save_finish
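
    # Illustration of the ``save_dir`` resolution in ``__init__`` above
    # (both paths below are hypothetical):
    #   CkptSaver(policy, save_dir='./exp')       -> files written under './exp/ckpt/'
    #   CkptSaver(policy, save_dir='./exp/ckpt')  -> files written under './exp/ckpt/'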
    def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None:
        """
        Overview:
            The method used to save checkpoint data. \
            The checkpoint data will be saved in a file in the following 3 cases: \
                - When at least ``self.train_freq`` training iterations have elapsed since the last save \
                  (and at iteration 0); \
                - When the evaluation episode return is the best so far; \
                - When ``task.finish`` is True.
        Input of ctx:
            - train_iter (:obj:`int`): Number of training iterations, i.e. the number of policy network updates.
            - eval_value (:obj:`float`): The evaluation episode return of the current iteration.
        """
        # Save after every `train_freq` training iterations (and at iteration 0).
        if self.train_freq:
            if ctx.train_iter == 0 or ctx.train_iter - self.last_save_iter >= self.train_freq:
                save_file(
                    "{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict()
                )
                self.last_save_iter = ctx.train_iter
        # Save when the evaluation episode return is the best so far.
        if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value:
            save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())
            self.max_eval_value = ctx.eval_value
        # Save the final checkpoint when the task finishes.
        if task.finish and self.save_finish:
            save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())