from abc import ABC, abstractmethod |
|
from contextlib import contextmanager |
|
from pathlib import Path |
|
import typing as tp |
|
|
|
import flashy |
|
import omegaconf |
|
import torch |
|
from torch import nn |
|
|
|
from .. import optim |
|
from ..optim import fsdp |
|
from ..utils import checkpoint |
|
from ..utils.autocast import TorchAutocast |
|
from ..utils.best_state import BestStateDictManager |
|
from ..utils.deadlock import DeadlockDetect |
|
from ..utils.profiler import Profiler |
|
from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng |
|
|
|
|
|
class StandardSolver(ABC, flashy.BaseSolver): |
|
"""Standard solver for AudioCraft. |
|
|
|
The standard solver implements a base training loop with the following stages: |
|
    train, valid, evaluate and generate, all of which are expected to be defined by
    solvers in AudioCraft. It also provides sensible default management of Dora history replay,
    checkpoint management across epochs, and logging configuration.
|
|
|
AudioCraft solvers must inherit from the StandardSolver and define the methods |
|
associated to each stage as well as the show, build_model and build_dataloaders methods. |
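
    A minimal subclass could look like the following sketch (the model class, optimizer
    setup, config fields and dataset construction are illustrative assumptions, not part
    of AudioCraft):

        class MySolver(StandardSolver):
            def build_model(self):
                self.model = MyModel(**self.cfg.model).to(self.device)
                self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.cfg.optim.lr)
                self.register_stateful('model', 'optimizer')
                self.register_best_state('model')
                self.register_ema('model')

            def build_dataloaders(self):
                self.dataloaders = {'train': ..., 'valid': ...}  # e.g. built from self.cfg.dataset

            def show(self):
                self.logger.info(self.model)

            def run_step(self, idx, batch, metrics):
                ...

            def evaluate(self):
                ...

            def generate(self):
                ...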
|
""" |
|
def __init__(self, cfg: omegaconf.DictConfig): |
|
super().__init__() |
|
self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}") |
|
self.logger.info(f"All XP logs are stored in {self.xp.folder}") |
|
self.cfg = cfg |
|
self.device = cfg.device |
|
self.model: nn.Module |
|
self._continue_best_source_keys = ['best_state', 'fsdp_best_state'] |
|
self._fsdp_modules: tp.List[fsdp.FSDP] = [] |
|
self._ema_sources: nn.ModuleDict = nn.ModuleDict() |
|
self.ema: tp.Optional[optim.ModuleDictEMA] = None |
|
self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict() |
|
self._log_updates = self.cfg.logging.get('log_updates', 10) |
|
if self.cfg.logging.log_tensorboard: |
|
self.init_tensorboard(**self.cfg.get('tensorboard')) |
|
        if self.cfg.logging.log_wandb:
|
self.init_wandb(**self.cfg.get('wandb')) |
|
|
|
|
|
        # Keep a copy of the best performing state, stored with the dtype matching
        # the FSDP parameter dtype or the autocast dtype when either is enabled.
        dtype_best: tp.Optional[torch.dtype] = None
|
if self.cfg.fsdp.use: |
|
dtype_best = getattr(torch, self.cfg.fsdp.param_dtype) |
|
assert isinstance(dtype_best, torch.dtype) |
|
elif self.cfg.autocast: |
|
dtype_best = getattr(torch, self.cfg.autocast_dtype) |
|
assert isinstance(dtype_best, torch.dtype) |
|
self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best) |
|
|
|
        # Full (unsharded) copy of the best state, only meaningful when FSDP is used.
        self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
|
self.register_stateful('best_state', 'fsdp_best_state') |
|
self._new_best_state: bool = False |
|
|
|
self.build_dataloaders() |
|
if self.cfg.execute_only is None: |
|
assert 'train' in self.dataloaders, "The train dataset split must be provided." |
|
assert 'valid' in self.dataloaders, "The valid dataset split must be provided." |
|
self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0 |
|
if self.cfg.optim.updates_per_epoch: |
|
self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch |
|
self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs |
|
|
|
self.build_model() |
|
self.logger.info("Model hash: %s", model_hash(self.model)) |
|
assert 'model' in self.stateful.sources, \ |
|
"Please register the model to stateful with self.register_stateful('model') in build_model." |
|
self.profiler = Profiler(self.model, **self.cfg.profiler) |
|
self.initialize_ema() |
|
self.register_stateful('ema') |
|
assert self.ema is None or 'ema' in self.stateful.sources, \ |
|
"Please register the ema to stateful with self.register_stateful('ema') in build_model." |
|
self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock) |
|
|
|
model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6 |
|
|
|
|
|
        # Rough estimate: weights + gradients + two optimizer moments (e.g. Adam) = 4 copies,
        # at 4 bytes per float32 parameter; model_size is in millions of params, hence / 1000 for GB.
        mem_usage = model_size * 4 * 4 / 1000
|
self.logger.info("Model size: %.2f M params", model_size) |
|
self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage) |
|
|
|
@property |
|
def autocast(self): |
|
"""Convenient autocast (or not) using the solver configuration.""" |
|
return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype) |
|
|
|
def _get_state_source(self, name) -> flashy.state.StateDictSource: |
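        # Helper to retrieve a state source registered through register_stateful().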
|
|
|
return self.stateful.sources[name] |
|
|
|
@property |
|
def best_metric_name(self) -> tp.Optional[str]: |
|
"""Metric name used to identify the best state. This metric should be stored in the metrics |
|
used on the stage for best state identification (most likely, `valid`). If None, then |
|
no best state is saved. |
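
        Example override (a sketch; the metric name 'ce' is illustrative and must be
        present in the pending metrics of the stage used for best state selection):

            @property
            def best_metric_name(self) -> tp.Optional[str]:
                return 'ce'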
|
""" |
|
return None |
|
|
|
def register_best_state(self, *args: str): |
|
"""Register state sources in `BestStateDictManager` to keep their best states along with their |
|
latest states. The best state will be used at evaluation stages instead of the latest states. |
|
|
|
Shortcut around `BestStateDictManager.register` method. You can pass any number of |
|
attribute, included nested attributes and those will be included into the checkpoints |
|
and automatically restored when `BaseSolver.restore` is called. |
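
        Usage (sketch, assuming 'model' was already registered with `register_stateful`):
            self.register_best_state('model')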
|
""" |
|
for name in args: |
|
            assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
            state_source = self._get_state_source(name)
            self.best_state.register(name, state_source)
|
|
|
def register_ema(self, *args: str): |
|
"""Register state sources for exponential moving average. |
|
|
|
The registered sources are used to instantiate a ModuleDictEMA instance. |
|
The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called |
|
and swapped with the original state sources with self.swap_ema_state() method. |
|
|
|
Usage: |
|
self.register_ema('model') |
|
""" |
|
assert self.ema is None, "Cannot register state source to already instantiated EMA." |
|
for name in args: |
|
self._ema_sources[name] = getattr(self, name) |
|
|
|
def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs): |
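        """Wrap the given module with FSDP according to the solver configuration,
        keeping track of the wrapped modules for later state dict manipulation."""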
|
model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs) |
|
if isinstance(model, fsdp.FSDP): |
|
self._fsdp_modules.append(model) |
|
return model |
|
|
|
def update_best_state_from_stage(self, stage_name: str = 'valid'): |
|
"""Update latest best state based on pending metrics of a given stage. This method relies |
|
on the `BestStateDictManager.update` method to update the best state_dict with latest weights |
|
if the registered states happen to match to the best performing setup. |
|
""" |
|
if self.best_metric_name is None: |
|
|
|
            # When no best metric is defined, the current state is always considered the best.
            self._new_best_state = True
|
self.logger.info("Updating best state with current state.") |
|
else: |
|
assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found." |
|
assert self.best_metric_name in self._pending_metrics[stage_name], \ |
|
f"Best metric not found in {stage_name} metrics. Cannot register best state" |
|
current_score = self._pending_metrics[stage_name][self.best_metric_name] |
|
all_best_metric_scores = [ |
|
past_metrics[stage_name][self.best_metric_name] |
|
for past_metrics in self.history |
|
] |
|
all_best_metric_scores.append(current_score) |
|
best_score = min(all_best_metric_scores) |
|
self._new_best_state = current_score == best_score |
|
if self._new_best_state: |
|
old_best = min(all_best_metric_scores[:-1] + [float('inf')]) |
|
self.logger.info( |
|
f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})") |
|
|
|
if self._new_best_state: |
|
if self.cfg.fsdp.use: |
|
|
|
|
|
                # Under FSDP, this yields an empty state dict on all ranks except rank 0,
                # which holds a full (unsharded) copy of the state.
                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
|
for name in self.best_state.states.keys(): |
|
state_source = self._get_state_source(name) |
|
self.best_state.update(name, state_source) |
|
|
|
                    # Keep the full copy in a separate dict so that it ends up in the checkpoint.
                    self.fsdp_best_state.update(self.best_state.state_dict())
|
|
|
|
|
            # The FSDP best state cannot be efficiently reloaded, so a second pass updates
            # the best state from the local (possibly sharded) sources.
            for name in self.best_state.states.keys():
|
state_source = self._get_state_source(name) |
|
self.best_state.update(name, state_source) |
|
|
|
def _load_new_state_dict(self, state_dict: dict) -> dict: |
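        """Load `state_dict` into the corresponding registered sources and return
        a copy of the previous states so that they can be restored afterwards."""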
|
old_states = {} |
|
for name, new_state in state_dict.items(): |
|
state_source = self._get_state_source(name) |
|
old_states[name] = copy_state(state_source.state_dict()) |
|
state_source.load_state_dict(new_state) |
|
return old_states |
|
|
|
@contextmanager |
|
def swap_best_state(self): |
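        """Context manager that temporarily loads the best state into the registered
        sources, restoring their original state on exit."""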
|
self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}") |
|
old_states = self._load_new_state_dict(self.best_state.state_dict()) |
|
try: |
|
yield |
|
finally: |
|
self.logger.debug("Swapping back from best to original state") |
|
for name, old_state in old_states.items(): |
|
state_source = self._get_state_source(name) |
|
state_source.load_state_dict(old_state) |
|
|
|
@contextmanager |
|
def swap_ema_state(self): |
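        """Context manager that temporarily loads the EMA state into the registered
        sources (a no-op when EMA is disabled), restoring the original state on exit."""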
|
if self.ema is None: |
|
yield |
|
else: |
|
ema_state_dict = self.ema.state_dict()['state'] |
|
self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}") |
|
old_states = self._load_new_state_dict(ema_state_dict) |
|
try: |
|
yield |
|
finally: |
|
self.logger.debug("Swapping back from EMA state to original state") |
|
for name, old_state in old_states.items(): |
|
state_source = self._get_state_source(name) |
|
state_source.load_state_dict(old_state) |
|
|
|
@property |
|
def is_training(self): |
|
return self.current_stage == 'train' |
|
|
|
def log_model_summary(self, model: nn.Module): |
|
"""Log model summary, architecture and size of the model.""" |
|
self.logger.info(model) |
|
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20  # assumes 4 bytes (fp32) per parameter
|
self.logger.info("Size: %.1f MB", mb) |
|
|
|
@abstractmethod |
|
def build_model(self): |
|
"""Method to implement to initialize model.""" |
|
... |
|
|
|
def initialize_ema(self): |
|
"""Initialize exponential moving average with the registered sources. |
|
        An EMA object is only created when EMA is enabled in the `optim.ema` configuration.
|
""" |
|
from .builders import get_ema |
|
self.ema = get_ema(self._ema_sources, self.cfg.optim.ema) |
|
if self.ema is None: |
|
self.logger.info('No EMA on the model.') |
|
else: |
|
assert self.cfg.optim.ema.updates > 0 |
|
self.logger.info( |
|
f'Initializing EMA on the model with decay = {self.ema.decay}' |
|
f' every {self.cfg.optim.ema.updates} updates' |
|
) |
|
|
|
@abstractmethod |
|
def build_dataloaders(self): |
|
"""Method to implement to initialize dataloaders.""" |
|
... |
|
|
|
@abstractmethod |
|
def show(self): |
|
"""Method to log any information without running the job.""" |
|
... |
|
|
|
@property |
|
def log_updates(self): |
|
|
|
        # Number of updates between two progress log lines, from cfg.logging.log_updates.
        return self._log_updates
|
|
|
def checkpoint_path(self, **kwargs): |
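        """Path to the main checkpoint of the current XP folder (FSDP-aware by default)."""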
|
kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) |
|
return self.folder / checkpoint.checkpoint_name(**kwargs) |
|
|
|
def epoch_checkpoint_path(self, epoch: int, **kwargs): |
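        """Path to the checkpoint kept around for the given epoch (FSDP-aware by default)."""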
|
kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) |
|
return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs) |
|
|
|
def checkpoint_path_with_name(self, name: str, **kwargs): |
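        """Path to a checkpoint with an explicit name (FSDP-aware by default)."""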
|
kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) |
|
return self.folder / checkpoint.checkpoint_name(name=name, **kwargs) |
|
|
|
def save_checkpoints(self): |
|
"""Save checkpoint, optionally keeping a copy for a given epoch.""" |
|
is_sharded = self.cfg.fsdp.use |
|
if not flashy.distrib.is_rank_zero() and not is_sharded: |
|
return |
|
self.logger.info("Model hash: %s", model_hash(self.model)) |
|
state = self.state_dict() |
|
        epoch = self.epoch - 1  # the epoch counter was already advanced when committing the history
|
|
|
|
|
        # Checkpoint kept around every `save_every` epochs, possibly with a reduced set of states.
        if self.cfg.checkpoint.save_every:
|
if epoch % self.cfg.checkpoint.save_every == 0: |
|
minimal_state = state |
|
if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0: |
|
minimal_state = { |
|
name: source for name, source in state.items() |
|
if name in self.cfg.checkpoint.keep_every_states |
|
} |
|
epoch_checkpoint_path = self.epoch_checkpoint_path(epoch) |
|
checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded) |
|
|
|
|
|
        # Checkpoint holding the latest state, overwritten at every save.
        if self.cfg.checkpoint.save_last:
|
last_checkpoint_path = self.checkpoint_path() |
|
checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded) |
|
|
|
|
|
checkpoint.flush_stale_checkpoints(self.checkpoint_path()) |
|
|
|
def load_from_pretrained(self, name: str) -> dict: |
|
raise NotImplementedError("Solver does not provide a way to load pretrained models.") |
|
|
|
def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]: |
|
"""Load last checkpoint or the one specified in continue_from. |
|
|
|
Args: |
|
load_best (bool): Whether to load from best state dict or not. |
|
Best state dict is always used when not loading the current xp. |
|
ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`. |
|
Returns: |
|
state (dict, optional): The loaded state dictionary. |
|
""" |
|
|
|
        # Load checkpoints either from the current XP folder or from cfg.continue_from.
        is_sharded = self.cfg.fsdp.use
|
load_from_path: tp.Optional[Path] = None |
|
checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None |
|
|
|
if load_best: |
|
self.logger.info("Trying to load state_dict from best state.") |
|
|
|
state: tp.Optional[dict] = None |
|
rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False) |
|
current_checkpoint_path = self.checkpoint_path() |
|
_pretrained_prefix = '//pretrained/' |
|
continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix) |
|
if rank0_checkpoint_path.exists(): |
|
self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}") |
|
load_from_path = current_checkpoint_path |
|
checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path) |
|
checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP |
|
elif self.cfg.continue_from and not continue_pretrained: |
|
self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}") |
|
|
|
            # continue_from always points at a consolidated (non-sharded) checkpoint.
            load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
|
if load_from_path is None: |
|
self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from) |
|
raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}') |
|
checkpoint_source = checkpoint.CheckpointSource.OTHER |
|
|
|
if load_from_path is not None: |
|
state = checkpoint.load_checkpoint(load_from_path, is_sharded) |
|
elif continue_pretrained: |
|
self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.") |
|
state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):]) |
|
checkpoint_source = checkpoint.CheckpointSource.PRETRAINED |
|
load_best = True |
|
|
|
|
|
        # The checkpoint does not come from the current XP: only the best state is kept.
        if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
|
assert state is not None |
|
self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.") |
|
load_best = True |
|
state = {key: state[key] for key in self._continue_best_source_keys if key in state} |
|
|
|
|
|
            # The loaded checkpoint carries an FSDP best state: keep it and drop the regular best state.
            if 'fsdp_best_state' in state and state['fsdp_best_state']:
|
state.pop('best_state', None) |
|
self.logger.info("... Loaded checkpoint has FSDP best state") |
|
|
|
|
|
            # FSDP is enabled but the loaded checkpoint has no FSDP best state:
            # reuse the regular best state as the FSDP best state.
            elif self.cfg.fsdp.use:
|
if 'fsdp_best_state' not in state or not state['fsdp_best_state']: |
|
|
|
state['fsdp_best_state'] = state.pop('best_state') |
|
self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state") |
|
|
|
if state is not None: |
|
if load_best: |
|
self.logger.info("Ignoring keys when loading best %r", ignore_state_keys) |
|
for key in set(ignore_state_keys): |
|
if key in state: |
|
state.pop(key) |
|
has_best_state = 'best_state' in state or 'fsdp_best_state' in state |
|
                assert has_best_state, \
                    "Trying to load best state but neither 'best_state' nor 'fsdp_best_state' found in checkpoint."
|
self.load_state_dict(state) |
|
|
|
|
|
|
|
        # Sanity check for distributed/FSDP runs: every rank must have loaded the same epoch.
        epoch = float(self.epoch)
|
avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch'] |
|
if avg_epoch != epoch: |
|
raise RuntimeError( |
|
f"Inconsistent loading of checkpoints happened, our epoch is {epoch} " |
|
f"but average of epochs is {avg_epoch}, at least one gpu must have a " |
|
"different epoch number.") |
|
|
|
|
|
|
|
        # When loading the best state, permanently re-initialize the stateful objects,
        # the best state manager and the EMA from it; otherwise keep the loaded state as-is.
        if load_best:
|
self.logger.info("Loading state_dict from best state.") |
|
if not self.cfg.fsdp.use and self.fsdp_best_state: |
|
|
|
self.logger.info("... Loading from FSDP best state dict.") |
|
self.best_state.load_state_dict(self.fsdp_best_state) |
|
|
|
|
|
            # Permanently override the current state_dict with the best state.
            if self.cfg.fsdp.use:
|
self.logger.info("FSDP is used, loading from FSDP best state.") |
|
with fsdp.switch_to_full_state_dict(self._fsdp_modules): |
|
|
|
self.load_state_dict(self.fsdp_best_state) |
|
else: |
|
|
|
                # Permanently swap the registered sources to their best state.
                self._load_new_state_dict(self.best_state.state_dict())
|
|
|
|
|
|
|
            # The EMA should also start from the best state: the simplest way is to re-initialize it.
            if self.ema is not None:
|
self.logger.info("Re-initializing EMA from best state") |
|
self.initialize_ema() |
|
|
|
if self.cfg.fsdp.use: |
|
self.logger.info("Re-initializing best state after using FSDP best state.") |
|
for name in self.best_state.states.keys(): |
|
state_source = self._get_state_source(name) |
|
self.best_state.update(name, state_source) |
|
|
|
return state |
|
|
|
def restore(self, load_best: bool = False, replay_metrics: bool = False, |
|
ignore_state_keys: tp.List[str] = []) -> bool: |
|
"""Restore the status of a solver for a given xp. |
|
|
|
Args: |
|
load_best (bool): if `True`, load the best state from the checkpoint. |
|
replay_metrics (bool): if `True`, logs all the metrics from past epochs. |
|
            ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
        Returns:
            bool: whether a checkpoint (or pretrained state) was restored.
|
""" |
|
self.logger.info("Restoring weights and history.") |
|
restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys) |
|
|
|
self.logger.info("Model hash: %s", model_hash(self.model)) |
|
|
|
if replay_metrics and len(self.history) > 0: |
|
self.logger.info("Replaying past metrics...") |
|
for epoch, stages in enumerate(self.history): |
|
for stage_name, metrics in stages.items(): |
|
|
|
|
|
                    # Log the summaries directly through the result logger so that replayed
                    # metrics are not added to the pending metrics of the current epoch.
                    self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch',
                                                    formatter=self.get_formatter(stage_name))
|
return restored_checkpoints is not None |
|
|
|
def commit(self, save_checkpoints: bool = True): |
|
"""Commit metrics to dora and save checkpoints at the end of an epoch.""" |
|
|
|
        # Appending the pending metrics to the history effectively advances self.epoch.
        self.history.append(self._pending_metrics)
|
if save_checkpoints: |
|
self.save_checkpoints() |
|
self._start_epoch() |
|
if flashy.distrib.is_rank_zero(): |
|
self.xp.link.update_history(self.history) |
|
|
|
def run_epoch(self): |
|
"""Run a single epoch with all stages. |
|
|
|
Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards. |
|
Children solvers can extend this method with custom behavior, e.g.: |
|
|
|
def run_epoch(self): |
|
... # custom code |
|
super().run_epoch() |
|
... # custom code |
|
""" |
|
self.run_stage('train', self.train) |
|
with torch.no_grad(): |
|
with self.swap_ema_state(): |
|
self.run_stage('valid', self.valid) |
|
|
|
self.update_best_state_from_stage('valid') |
|
with self.swap_best_state(): |
|
if self.should_run_stage('evaluate'): |
|
self.run_stage('evaluate', self.evaluate) |
|
if self.should_run_stage('generate'): |
|
self.run_stage('generate', with_rank_rng()(self.generate)) |
|
|
|
def run(self): |
|
"""Training loop.""" |
|
assert len(self.state_dict()) > 0 |
|
        self.restore(replay_metrics=True)  # load the checkpoint and replay the metric history
|
self.log_hyperparams(dict_from_config(self.cfg)) |
|
for epoch in range(self.epoch, self.cfg.optim.epochs + 1): |
|
if self.should_stop_training(): |
|
return |
|
self.run_epoch() |
|
|
|
            # Commit sends the metrics to Dora and saves checkpoints at the end of each epoch.
            self.commit()
|
|
|
def should_stop_training(self) -> bool: |
|
"""Check whether we should stop training or not.""" |
|
return self.epoch > self.cfg.optim.epochs |
|
|
|
def should_run_stage(self, stage_name) -> bool: |
|
"""Check whether we want to run the specified stages.""" |
|
stage_every = self.cfg[stage_name].get('every', None) |
|
is_last_epoch = self.epoch == self.cfg.optim.epochs |
|
is_epoch_every = (stage_every and self.epoch % stage_every == 0) |
|
return is_last_epoch or is_epoch_every |
|
|
|
@abstractmethod |
|
def run_step(self, idx: int, batch: tp.Any, metrics: dict): |
|
"""Perform one training or valid step on a given batch.""" |
|
... |
|
|
|
def common_train_valid(self, dataset_split: str, **kwargs: tp.Any): |
|
"""Common logic for train and valid stages.""" |
|
self.model.train(self.is_training) |
|
|
|
loader = self.dataloaders[dataset_split] |
|
|
|
        # Make sure the DistributedSampler reshuffles differently at each epoch.
        if flashy.distrib.world_size() > 1 \
|
and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler): |
|
loader.sampler.set_epoch(self.epoch) |
|
updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader) |
|
if self.cfg.benchmark_no_load: |
|
self.logger.warning("Fake loading for benchmarking: re-using first batch") |
|
batch = next(iter(loader)) |
|
loader = [batch] * updates_per_epoch |
|
lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates) |
|
average = flashy.averager() |
|
instant_average = flashy.averager() |
|
metrics: dict = {} |
|
|
|
with self.profiler, self.deadlock_detect: |
|
for idx, batch in enumerate(lp): |
|
self.deadlock_detect.update('batch') |
|
if idx >= updates_per_epoch: |
|
break |
|
metrics = {} |
|
metrics = self.run_step(idx, batch, metrics) |
|
self.deadlock_detect.update('step') |
|
|
|
if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0: |
|
self.logger.debug("EMA model step") |
|
self.ema.step() |
|
self.deadlock_detect.update('ema') |
|
self.profiler.step() |
|
instant_metrics = instant_average(metrics) |
|
if lp.update(**instant_metrics): |
|
instant_average = flashy.averager() |
|
                metrics = average(metrics)  # running epoch-wise average
|
self.deadlock_detect.update('end_batch') |
|
|
|
metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch) |
|
return metrics |
|
|
|
def train(self): |
|
"""Train stage.""" |
|
return self.common_train_valid('train') |
|
|
|
def valid(self): |
|
"""Valid stage.""" |
|
return self.common_train_valid('valid') |
|
|
|
@abstractmethod |
|
def evaluate(self): |
|
"""Evaluate stage.""" |
|
... |
|
|
|
@abstractmethod |
|
def generate(self): |
|
"""Generate stage.""" |
|
... |
|
|
|
def run_one_stage(self, stage_name: str): |
|
"""Run only the specified stage. |
|
This method is useful to only generate samples from a trained experiment |
|
or rerun the validation or evaluation stages. |
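
        Usage (sketch):
            solver.run_one_stage('generate')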
|
""" |
|
fn = { |
|
'generate': with_rank_rng()(self.generate), |
|
'evaluate': self.evaluate, |
|
'valid': self.valid, |
|
} |
|
if stage_name not in fn: |
|
            raise ValueError(f'Running stage {stage_name} is not supported.')
|
assert len(self.state_dict()) > 0 |
|
self._start_epoch() |
|
with torch.no_grad(), self.swap_best_state(): |
|
self.run_stage(stage_name, fn[stage_name]) |
|
if not self.cfg.execute_inplace: |
|
self.commit(save_checkpoints=False) |
|
|
|
@staticmethod |
|
def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, |
|
device: tp.Optional[str] = None, autocast: bool = True, |
|
batch_size: tp.Optional[int] = None, |
|
override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, |
|
**kwargs): |
|
"""Mostly a convenience function around audiocraft.train.get_solver_from_sig, |
|
        populating all the proper params, deactivating EMA and FSDP, loading the best state,
|
basically all you need to get a solver ready to "play" with in single GPU mode |
|
and with minimal memory overhead. |
|
|
|
Args: |
|
sig (str): signature to load. |
|
dtype (str or None): potential dtype, as a string, i.e. 'float16'. |
|
device (str or None): potential device, as a string, i.e. 'cuda'. |
|
            autocast (bool): whether to keep autocast enabled (defaults to True).
            batch_size (int or None): optional batch size override for the dataset configuration.
            override_cfg (dict or omegaconf.DictConfig or None): additional configuration overrides,
                merged with (and potentially overridden by) the overrides set by this helper.
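
        Usage (sketch; the signature string is a placeholder):
            solver = StandardSolver.get_eval_solver_from_sig('<xp_sig>', device='cuda', batch_size=8)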
|
""" |
|
from audiocraft import train |
|
our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}} |
|
our_override_cfg['autocast'] = autocast |
|
if dtype is not None: |
|
our_override_cfg['dtype'] = dtype |
|
if device is not None: |
|
our_override_cfg['device'] = device |
|
if batch_size is not None: |
|
our_override_cfg['dataset'] = {'batch_size': batch_size} |
|
if override_cfg is None: |
|
override_cfg = {} |
|
override_cfg = omegaconf.OmegaConf.merge( |
|
omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) |
|
solver = train.get_solver_from_sig( |
|
sig, override_cfg=override_cfg, |
|
load_best=True, disable_fsdp=True, |
|
ignore_state_keys=['optimizer', 'ema'], **kwargs) |
|
solver.model.eval() |
|
return solver |
|
|