Spaces: unpairedelectron07
Commit e3061ad
Parent(s): 699b46d

Upload 7 files

- audiocraft/solvers/audiogen.py +19 -0
- audiocraft/solvers/base.py +631 -0
- audiocraft/solvers/builders.py +366 -0
- audiocraft/solvers/compression.py +328 -0
- audiocraft/solvers/diffusion.py +279 -0
- audiocraft/solvers/magnet.py +276 -0
- audiocraft/solvers/musicgen.py +721 -0
audiocraft/solvers/audiogen.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from . import builders, musicgen


class AudioGenSolver(musicgen.MusicGenSolver):
    """Solver for the AudioGen re-implementation training task.

    Note that this implementation does not strictly follow
    the method proposed in https://arxiv.org/abs/2209.15352
    but is derived from MusicGen's training pipeline.

    More information can be found in the AudioGen model card.
    """
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND
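Aside, not part of the commit: a minimal sketch of what this subclassing buys. AudioGenSolver inherits MusicGen's full training pipeline and only redefines the dataset family; the `get_audio_datasets` call in the comment is an assumption about how the inherited `build_dataloaders` consumes `DATASET_TYPE`, not code from this diff.

from audiocraft.solvers import builders
from audiocraft.solvers.audiogen import AudioGenSolver

# The class attribute is the whole specialization: sound data instead of music.
assert AudioGenSolver.DATASET_TYPE is builders.DatasetType.SOUND

# Inside the inherited build_dataloaders, this presumably boils down to:
#   self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)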
audiocraft/solvers/base.py
ADDED
@@ -0,0 +1,631 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
import typing as tp

import flashy
import omegaconf
import torch
from torch import nn

from .. import optim
from ..optim import fsdp
from ..utils import checkpoint
from ..utils.autocast import TorchAutocast
from ..utils.best_state import BestStateDictManager
from ..utils.deadlock import DeadlockDetect
from ..utils.profiler import Profiler
from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng


class StandardSolver(ABC, flashy.BaseSolver):
    """Standard solver for AudioCraft.

    The standard solver implements a base training loop with the following stages:
    train, valid, evaluate and generate, which are all expected to be defined for
    solvers in AudioCraft. It also provides a convenient default management of Dora
    history replay, checkpoint management across epochs, and logging configuration.

    AudioCraft solvers must inherit from the StandardSolver and define the methods
    associated with each stage as well as the show, build_model and build_dataloaders methods.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__()
        self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
        self.logger.info(f"All XP logs are stored in {self.xp.folder}")
        self.cfg = cfg
        self.device = cfg.device
        self.model: nn.Module
        self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
        self._fsdp_modules: tp.List[fsdp.FSDP] = []
        self._ema_sources: nn.ModuleDict = nn.ModuleDict()
        self.ema: tp.Optional[optim.ModuleDictEMA] = None
        self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
        self._log_updates = self.cfg.logging.get('log_updates', 10)
        if self.cfg.logging.log_tensorboard:
            self.init_tensorboard(**self.cfg.get('tensorboard'))
        if self.cfg.logging.log_wandb and self:
            self.init_wandb(**self.cfg.get('wandb'))
        # keep a copy of the best performing state for stateful objects
        # used for evaluation and generation stages
        dtype_best: tp.Optional[torch.dtype] = None
        if self.cfg.fsdp.use:
            dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
            assert isinstance(dtype_best, torch.dtype)
        elif self.cfg.autocast:
            dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
            assert isinstance(dtype_best, torch.dtype)
        self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
        # Hacky support for keeping a copy of the full best state in rank0.
        self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
        self.register_stateful('best_state', 'fsdp_best_state')  # register best_state object to keep it in state_dict
        self._new_best_state: bool = False  # should save a new checkpoint
        # instantiate datasets and appropriate number of updates per epoch
        self.build_dataloaders()
        if self.cfg.execute_only is None:
            assert 'train' in self.dataloaders, "The train dataset split must be provided."
            assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
        self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
        if self.cfg.optim.updates_per_epoch:
            self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
        self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
        # instantiate model & exponential moving average on the model
        self.build_model()
        self.logger.info("Model hash: %s", model_hash(self.model))
        assert 'model' in self.stateful.sources, \
            "Please register the model to stateful with self.register_stateful('model') in build_model."
        self.profiler = Profiler(self.model, **self.cfg.profiler)
        self.initialize_ema()
        self.register_stateful('ema')
        assert self.ema is None or 'ema' in self.stateful.sources, \
            "Please register the ema to stateful with self.register_stateful('ema') in build_model."
        self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
        # basic statistics on the trained model
        model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
        # one copy of grad, one copy of momentum, one copy of denominator and model weights,
        # and 4 bytes for each float!
        mem_usage = model_size * 4 * 4 / 1000
        self.logger.info("Model size: %.2f M params", model_size)
        self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)

    @property
    def autocast(self):
        """Convenient autocast (or not) using the solver configuration."""
        return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype)

    def _get_state_source(self, name) -> flashy.state.StateDictSource:
        # Internal utility to get a state source from the solver
        return self.stateful.sources[name]

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Metric name used to identify the best state. This metric should be stored in the metrics
        used on the stage for best state identification (most likely, `valid`). If None, then
        no best state is saved.
        """
        return None

    def register_best_state(self, *args: str):
        """Register state sources in `BestStateDictManager` to keep their best states along with their
        latest states. The best state will be used at evaluation stages instead of the latest states.

        Shortcut around the `BestStateDictManager.register` method. You can pass any number of
        attributes, including nested attributes; those will be included in the checkpoints
        and automatically restored when `BaseSolver.restore` is called.
        """
        for name in args:
            state_source = self._get_state_source(name)
            assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
            self.best_state.register(name, state_source)

    def register_ema(self, *args: str):
        """Register state sources for exponential moving average.

        The registered sources are used to instantiate a ModuleDictEMA instance.
        The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
        and swapped with the original state sources with the self.swap_ema_state() method.

        Usage:
            self.register_ema('model')
        """
        assert self.ema is None, "Cannot register state source to already instantiated EMA."
        for name in args:
            self._ema_sources[name] = getattr(self, name)

    def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
        model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
        if isinstance(model, fsdp.FSDP):
            self._fsdp_modules.append(model)
        return model

    def update_best_state_from_stage(self, stage_name: str = 'valid'):
        """Update latest best state based on pending metrics of a given stage. This method relies
        on the `BestStateDictManager.update` method to update the best state_dict with latest weights
        if the registered states happen to match the best performing setup.
        """
        if self.best_metric_name is None:
            # when no best metric is defined, the last state is always the best
            self._new_best_state = True
            self.logger.info("Updating best state with current state.")
        else:
            assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
            assert self.best_metric_name in self._pending_metrics[stage_name], \
                f"Best metric not found in {stage_name} metrics. Cannot register best state"
            current_score = self._pending_metrics[stage_name][self.best_metric_name]
            all_best_metric_scores = [
                past_metrics[stage_name][self.best_metric_name]
                for past_metrics in self.history
            ]
            all_best_metric_scores.append(current_score)
            best_score = min(all_best_metric_scores)
            self._new_best_state = current_score == best_score
            if self._new_best_state:
                old_best = min(all_best_metric_scores[:-1] + [float('inf')])
                self.logger.info(
                    f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")

        if self._new_best_state:
            if self.cfg.fsdp.use:
                # this will give an empty state dict on all ranks but the rank 0
                # which will have a copy in memory of the full model.
                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                    for name in self.best_state.states.keys():
                        state_source = self._get_state_source(name)
                        self.best_state.update(name, state_source)
                    # we save to a different dict.
                    self.fsdp_best_state.update(self.best_state.state_dict())
                # We cannot efficiently load fsdp_best_state when using FSDP,
                # so we have to do a second pass, with the local shards.
            for name in self.best_state.states.keys():
                state_source = self._get_state_source(name)
                self.best_state.update(name, state_source)

    def _load_new_state_dict(self, state_dict: dict) -> dict:
        old_states = {}
        for name, new_state in state_dict.items():
            state_source = self._get_state_source(name)
            old_states[name] = copy_state(state_source.state_dict())
            state_source.load_state_dict(new_state)
        return old_states

    @contextmanager
    def swap_best_state(self):
        self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
        old_states = self._load_new_state_dict(self.best_state.state_dict())
        try:
            yield
        finally:
            self.logger.debug("Swapping back from best to original state")
            for name, old_state in old_states.items():
                state_source = self._get_state_source(name)
                state_source.load_state_dict(old_state)

    @contextmanager
    def swap_ema_state(self):
        if self.ema is None:
            yield
        else:
            ema_state_dict = self.ema.state_dict()['state']
            self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
            old_states = self._load_new_state_dict(ema_state_dict)
            try:
                yield
            finally:
                self.logger.debug("Swapping back from EMA state to original state")
                for name, old_state in old_states.items():
                    state_source = self._get_state_source(name)
                    state_source.load_state_dict(old_state)

    @property
    def is_training(self):
        return self.current_stage == 'train'

    def log_model_summary(self, model: nn.Module):
        """Log model summary, architecture and size of the model."""
        self.logger.info(model)
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
        self.logger.info("Size: %.1f MB", mb)

    @abstractmethod
    def build_model(self):
        """Method to implement to initialize the model."""
        ...

    def initialize_ema(self):
        """Initialize exponential moving average with the registered sources.
        The EMA object is created if the optim.ema.model.decay value is non-null.
        """
        from .builders import get_ema
        self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
        if self.ema is None:
            self.logger.info('No EMA on the model.')
        else:
            assert self.cfg.optim.ema.updates > 0
            self.logger.info(
                f'Initializing EMA on the model with decay = {self.ema.decay}'
                f' every {self.cfg.optim.ema.updates} updates'
            )

    @abstractmethod
    def build_dataloaders(self):
        """Method to implement to initialize the dataloaders."""
        ...

    @abstractmethod
    def show(self):
        """Method to log any information without running the job."""
        ...

    @property
    def log_updates(self):
        # convenient access to log updates
        return self._log_updates

    def checkpoint_path(self, **kwargs):
        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
        return self.folder / checkpoint.checkpoint_name(**kwargs)

    def epoch_checkpoint_path(self, epoch: int, **kwargs):
        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
        return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs)

    def checkpoint_path_with_name(self, name: str, **kwargs):
        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
        return self.folder / checkpoint.checkpoint_name(name=name, **kwargs)

    def save_checkpoints(self):
        """Save checkpoint, optionally keeping a copy for a given epoch."""
        is_sharded = self.cfg.fsdp.use
        if not flashy.distrib.is_rank_zero() and not is_sharded:
            return
        self.logger.info("Model hash: %s", model_hash(self.model))
        state = self.state_dict()
        epoch = self.epoch - 1  # pushing metrics will increase the epoch in Flashy, so we do -1 here

        # save minimal state_dict as new checkpoint every X epochs
        if self.cfg.checkpoint.save_every:
            if epoch % self.cfg.checkpoint.save_every == 0:
                minimal_state = state
                if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0:
                    minimal_state = {
                        name: source for name, source in state.items()
                        if name in self.cfg.checkpoint.keep_every_states
                    }
                epoch_checkpoint_path = self.epoch_checkpoint_path(epoch)
                checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded)

        # save checkpoint as latest checkpoint
        if self.cfg.checkpoint.save_last:
            last_checkpoint_path = self.checkpoint_path()
            checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded)

        # flush any stale checkpoint to reduce disk footprint
        checkpoint.flush_stale_checkpoints(self.checkpoint_path())

    def load_from_pretrained(self, name: str) -> dict:
        raise NotImplementedError("Solver does not provide a way to load pretrained models.")

    def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
        """Load the last checkpoint or the one specified in continue_from.

        Args:
            load_best (bool): Whether to load from best state dict or not.
                Best state dict is always used when not loading the current xp.
            ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
        Returns:
            state (dict, optional): The loaded state dictionary.
        """
        # load checkpoints from xp folder or cfg.continue_from
        is_sharded = self.cfg.fsdp.use
        load_from_path: tp.Optional[Path] = None
        checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None

        if load_best:
            self.logger.info("Trying to load state_dict from best state.")

        state: tp.Optional[dict] = None
        rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
        current_checkpoint_path = self.checkpoint_path()
        _pretrained_prefix = '//pretrained/'
        continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
        if rank0_checkpoint_path.exists():
            self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
            load_from_path = current_checkpoint_path
            checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
            checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
        elif self.cfg.continue_from and not continue_pretrained:
            self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
            # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
            load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
            if load_from_path is None:
                self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
                raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
            checkpoint_source = checkpoint.CheckpointSource.OTHER

        if load_from_path is not None:
            state = checkpoint.load_checkpoint(load_from_path, is_sharded)
        elif continue_pretrained:
            self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
            state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
            checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
            load_best = True

        # checkpoints are not from the current xp, we only retrieve the best state
        if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
            assert state is not None
            self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
            load_best = True
            state = {key: state[key] for key in self._continue_best_source_keys if key in state}
            # loaded checkpoints are FSDP checkpoints: we're reading the best state
            # from FSDP and we drop the regular best_state
            if 'fsdp_best_state' in state and state['fsdp_best_state']:
                state.pop('best_state', None)
                self.logger.info("... Loaded checkpoint has FSDP best state")
            # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
            # then we're initializing FSDP best state with the regular best state
            elif self.cfg.fsdp.use:
                if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
                    # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
                    state['fsdp_best_state'] = state.pop('best_state')
                    self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")

        if state is not None:
            if load_best:
                self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
                for key in set(ignore_state_keys):
                    if key in state:
                        state.pop(key)
                has_best_state = 'best_state' in state or 'fsdp_best_state' in state
                assert has_best_state, ("Trying to load best state but neither 'best_state'"
                                        " nor 'fsdp_best_state' found in checkpoints.")
            self.load_state_dict(state)

        # for FSDP, let's make extra sure nothing bad happened with out-of-sync
        # checkpoints across workers.
        epoch = float(self.epoch)
        avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
        if avg_epoch != epoch:
            raise RuntimeError(
                f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
                f"but average of epochs is {avg_epoch}, at least one GPU must have a "
                "different epoch number.")

        # on load_best, properly reinitialize state_dict, best states and ema
        # otherwise we load from the current xp and don't alter anything
        if load_best:
            self.logger.info("Loading state_dict from best state.")
            if not self.cfg.fsdp.use and self.fsdp_best_state:
                # loading from an FSDP checkpoint but with FSDP deactivated
                self.logger.info("... Loading from FSDP best state dict.")
                self.best_state.load_state_dict(self.fsdp_best_state)

            # if load_best, we permanently override the regular state_dict with the best state
            if self.cfg.fsdp.use:
                self.logger.info("FSDP is used, loading from FSDP best state.")
                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                    # this might be really fragile but okay for now.
                    self.load_state_dict(self.fsdp_best_state)
            else:
                # we permanently swap the stateful objects to their best state
                self._load_new_state_dict(self.best_state.state_dict())

            # the EMA modules should also be instantiated with best state.
            # the easiest way to do so is to reinitialize a new EMA with best state loaded.
            if self.ema is not None:
                self.logger.info("Re-initializing EMA from best state")
                self.initialize_ema()

            if self.cfg.fsdp.use:
                self.logger.info("Re-initializing best state after using FSDP best state.")
                for name in self.best_state.states.keys():
                    state_source = self._get_state_source(name)
                    self.best_state.update(name, state_source)

        return state

    def restore(self, load_best: bool = False, replay_metrics: bool = False,
                ignore_state_keys: tp.List[str] = []) -> bool:
        """Restore the status of a solver for a given xp.

        Args:
            load_best (bool): if `True`, load the best state from the checkpoint.
            replay_metrics (bool): if `True`, logs all the metrics from past epochs.
            ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
        """
        self.logger.info("Restoring weights and history.")
        restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys)

        self.logger.info("Model hash: %s", model_hash(self.model))

        if replay_metrics and len(self.history) > 0:
            self.logger.info("Replaying past metrics...")
            for epoch, stages in enumerate(self.history):
                for stage_name, metrics in stages.items():
                    # We manually log the metrics summary to the result logger
                    # as we don't want to add them to the pending metrics
                    self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch',
                                                    formatter=self.get_formatter(stage_name))
        return restored_checkpoints is not None

    def commit(self, save_checkpoints: bool = True):
        """Commit metrics to Dora and save checkpoints at the end of an epoch."""
        # we override commit to introduce more complex checkpoint saving behaviors
        self.history.append(self._pending_metrics)  # This will increase self.epoch
        if save_checkpoints:
            self.save_checkpoints()
        self._start_epoch()
        if flashy.distrib.is_rank_zero():
            self.xp.link.update_history(self.history)

    def run_epoch(self):
        """Run a single epoch with all stages.

        Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards.
        Children solvers can extend this method with custom behavior, e.g.:

        def run_epoch(self):
            ...  # custom code
            super().run_epoch()
            ...  # custom code
        """
        self.run_stage('train', self.train)
        with torch.no_grad():
            with self.swap_ema_state():
                self.run_stage('valid', self.valid)
                # the best state is updated with EMA states if available
                self.update_best_state_from_stage('valid')
            with self.swap_best_state():
                if self.should_run_stage('evaluate'):
                    self.run_stage('evaluate', self.evaluate)
                if self.should_run_stage('generate'):
                    self.run_stage('generate', with_rank_rng()(self.generate))

    def run(self):
        """Training loop."""
        assert len(self.state_dict()) > 0
        self.restore(replay_metrics=True)  # load checkpoint and replay history
        self.log_hyperparams(dict_from_config(self.cfg))
        for epoch in range(self.epoch, self.cfg.optim.epochs + 1):
            if self.should_stop_training():
                return
            self.run_epoch()
            # Commit will send the metrics to Dora and save checkpoints by default.
            self.commit()

    def should_stop_training(self) -> bool:
        """Check whether we should stop training or not."""
        return self.epoch > self.cfg.optim.epochs

    def should_run_stage(self, stage_name) -> bool:
        """Check whether we want to run the specified stage."""
        stage_every = self.cfg[stage_name].get('every', None)
        is_last_epoch = self.epoch == self.cfg.optim.epochs
        is_epoch_every = (stage_every and self.epoch % stage_every == 0)
        return is_last_epoch or is_epoch_every

    @abstractmethod
    def run_step(self, idx: int, batch: tp.Any, metrics: dict):
        """Perform one training or valid step on a given batch."""
        ...

    def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
        """Common logic for train and valid stages."""
        self.model.train(self.is_training)

        loader = self.dataloaders[dataset_split]
        # get a different order for distributed training, otherwise this will get ignored
        if flashy.distrib.world_size() > 1 \
                and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
            loader.sampler.set_epoch(self.epoch)
        updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
        if self.cfg.benchmark_no_load:
            self.logger.warning("Fake loading for benchmarking: re-using first batch")
            batch = next(iter(loader))
            loader = [batch] * updates_per_epoch  # type: ignore
        lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
        average = flashy.averager()  # epoch-wise average
        instant_average = flashy.averager()  # average between two loggings
        metrics: dict = {}

        with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
            for idx, batch in enumerate(lp):
                self.deadlock_detect.update('batch')
                if idx >= updates_per_epoch:
                    break
                metrics = {}
                metrics = self.run_step(idx, batch, metrics)
                self.deadlock_detect.update('step')
                # run EMA step
                if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
                    self.logger.debug("EMA model step")
                    self.ema.step()
                self.deadlock_detect.update('ema')
                self.profiler.step()
                instant_metrics = instant_average(metrics)
                if lp.update(**instant_metrics):
                    instant_average = flashy.averager()  # reset averager between two loggings
                metrics = average(metrics)  # epoch-wise average
                self.deadlock_detect.update('end_batch')

        metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
        return metrics

    def train(self):
        """Train stage."""
        return self.common_train_valid('train')

    def valid(self):
        """Valid stage."""
        return self.common_train_valid('valid')

    @abstractmethod
    def evaluate(self):
        """Evaluate stage."""
        ...

    @abstractmethod
    def generate(self):
        """Generate stage."""
        ...

    def run_one_stage(self, stage_name: str):
        """Run only the specified stage.
        This method is useful to only generate samples from a trained experiment
        or rerun the validation or evaluation stages.
        """
        fn = {
            'generate': with_rank_rng()(self.generate),
            'evaluate': self.evaluate,
            'valid': self.valid,
        }
        if stage_name not in fn:
            raise ValueError(f'Running stage {stage_name} is not supported.')
        assert len(self.state_dict()) > 0
        self._start_epoch()
        with torch.no_grad(), self.swap_best_state():
            self.run_stage(stage_name, fn[stage_name])
        if not self.cfg.execute_inplace:
            self.commit(save_checkpoints=False)

    @staticmethod
    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                                 device: tp.Optional[str] = None, autocast: bool = True,
                                 batch_size: tp.Optional[int] = None,
                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                                 **kwargs):
        """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
        populating all the proper params, deactivating EMA and FSDP, and loading the best state:
        basically all you need to get a solver ready to "play" with in single-GPU mode
        and with minimal memory overhead.

        Args:
            sig (str): signature to load.
            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
            device (str or None): potential device, as a string, i.e. 'cuda'.
            override_cfg (dict or omegaconf.DictConfig or None): potential additional config overrides.
        """
        from audiocraft import train
        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
        our_override_cfg['autocast'] = autocast
        if dtype is not None:
            our_override_cfg['dtype'] = dtype
        if device is not None:
            our_override_cfg['device'] = device
        if batch_size is not None:
            our_override_cfg['dataset'] = {'batch_size': batch_size}
        if override_cfg is None:
            override_cfg = {}
        override_cfg = omegaconf.OmegaConf.merge(
            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
        solver = train.get_solver_from_sig(
            sig, override_cfg=override_cfg,
            load_best=True, disable_fsdp=True,
            ignore_state_keys=['optimizer', 'ema'], **kwargs)
        solver.model.eval()
        return solver
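Aside, not part of the commit: a skeletal subclass showing the contract above, i.e. the abstract methods (`build_model`, `build_dataloaders`, `show`, `run_step`, `evaluate`, `generate`) plus the `register_stateful('model')` call enforced by the assert in `__init__`. The toy linear model and the returned metric are hypothetical; this sketches the interface, not a working solver config.

import typing as tp

import torch

from audiocraft.solvers import base, builders


class ToySolver(base.StandardSolver):
    """Minimal illustration of the StandardSolver contract."""

    def build_model(self):
        self.model = torch.nn.Linear(8, 8).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        # 'model' must be registered, as enforced by the assert in __init__.
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        self.log_model_summary(self.model)

    def run_step(self, idx: int, batch: tp.Any, metrics: dict):
        x = batch.to(self.device)
        loss = self.model(x).pow(2).mean()
        if self.is_training:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        metrics['loss'] = loss
        return metrics

    def evaluate(self):
        return self.common_train_valid('evaluate')

    def generate(self):
        return {}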
audiocraft/solvers/builders.py
ADDED
@@ -0,0 +1,366 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
All the functions to build the relevant solvers and used objects
from the Hydra config.
"""

from enum import Enum
import logging
import typing as tp

import dora
import flashy
import omegaconf
import torch
from torch import nn
from torch.optim import Optimizer
# LRScheduler was renamed in some torch versions
try:
    from torch.optim.lr_scheduler import LRScheduler  # type: ignore
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from .base import StandardSolver
from .. import adversarial, data, losses, metrics, optim
from ..utils.utils import dict_from_config, get_loader


logger = logging.getLogger(__name__)


class DatasetType(Enum):
    AUDIO = "audio"
    MUSIC = "music"
    SOUND = "sound"


def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
    """Instantiate solver from config."""
    from .audiogen import AudioGenSolver
    from .compression import CompressionSolver
    from .musicgen import MusicGenSolver
    from .diffusion import DiffusionSolver
    from .magnet import MagnetSolver, AudioMagnetSolver
    klass = {
        'compression': CompressionSolver,
        'musicgen': MusicGenSolver,
        'audiogen': AudioGenSolver,
        'magnet': MagnetSolver,
        'audio_magnet': AudioMagnetSolver,
        'lm': MusicGenSolver,  # backward compatibility
        'diffusion': DiffusionSolver,
        'sound_lm': AudioGenSolver,  # backward compatibility
    }[cfg.solver]
    return klass(cfg)  # type: ignore


def get_optim_parameter_groups(model: nn.Module):
    """Create parameter groups for the model, using each module's `make_optim_group`
    method, when defined, to create the different groups.

    Args:
        model (nn.Module): torch model
    Returns:
        List of parameter groups
    """
    seen_params: tp.Set[nn.parameter.Parameter] = set()
    other_params = []
    groups = []
    for name, module in model.named_modules():
        if hasattr(module, 'make_optim_group'):
            group = module.make_optim_group()
            params = set(group['params'])
            assert params.isdisjoint(seen_params)
            seen_params |= set(params)
            groups.append(group)
    for param in model.parameters():
        if param not in seen_params:
            other_params.append(param)
    groups.insert(0, {'params': other_params})
    parameters = groups
    return parameters


def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer:
    """Build torch optimizer from config and set of parameters.
    Supported optimizers: Adam, AdamW, DAdaptAdam.

    Args:
        params (nn.Module or iterable of torch.Tensor): Parameters to optimize.
        cfg (DictConfig): Optimization-related configuration.
    Returns:
        torch.optim.Optimizer.
    """
    if 'optimizer' not in cfg:
        if getattr(cfg, 'optim', None) is not None:
            raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?")
        else:
            raise KeyError("Optimizer not found in config.")

    parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params
    optimizer: torch.optim.Optimizer
    if cfg.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam)
    elif cfg.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam)
    elif cfg.optimizer == 'dadam':
        optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam)
    else:
        raise ValueError(f"Unsupported Optimizer: {cfg.optimizer}")
    return optimizer


def get_lr_scheduler(optimizer: torch.optim.Optimizer,
                     cfg: omegaconf.DictConfig,
                     total_updates: int) -> tp.Optional[LRScheduler]:
    """Build torch learning rate scheduler from config and associated optimizer.
    Supported learning rate schedulers: step, exponential, cosine, polynomial_decay,
    inverse_sqrt and linear_warmup.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer.
        cfg (DictConfig): Schedule-related configuration.
        total_updates (int): Total number of updates.
    Returns:
        The learning rate scheduler, or None if the config defines none.
    """
    if 'lr_scheduler' not in cfg:
        raise KeyError("LR Scheduler not found in config")

    lr_sched: tp.Optional[LRScheduler] = None
    if cfg.lr_scheduler == 'step':
        lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step)
    elif cfg.lr_scheduler == 'exponential':
        lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential)
    elif cfg.lr_scheduler == 'cosine':
        kwargs = dict_from_config(cfg.cosine)
        warmup_steps = kwargs.pop('warmup')
        lr_sched = optim.CosineLRScheduler(
            optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
    elif cfg.lr_scheduler == 'polynomial_decay':
        kwargs = dict_from_config(cfg.polynomial_decay)
        warmup_steps = kwargs.pop('warmup')
        lr_sched = optim.PolynomialDecayLRScheduler(
            optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
    elif cfg.lr_scheduler == 'inverse_sqrt':
        kwargs = dict_from_config(cfg.inverse_sqrt)
        warmup_steps = kwargs.pop('warmup')
        lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
    elif cfg.lr_scheduler == 'linear_warmup':
        kwargs = dict_from_config(cfg.linear_warmup)
        warmup_steps = kwargs.pop('warmup')
        lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
    elif cfg.lr_scheduler is not None:
        raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
    return lr_sched


def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]:
    """Initialize Exponential Moving Average.

    Args:
        module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA.
        cfg (omegaconf.DictConfig): Optim EMA configuration.
    Returns:
        optim.ModuleDictEMA: EMA version of the ModuleDict.
    """
    kw: tp.Dict[str, tp.Any] = dict(cfg)
    use = kw.pop('use', False)
    decay = kw.pop('decay', None)
    device = kw.pop('device', None)
    if not use:
        return None
    if len(module_dict) == 0:
        raise ValueError("Trying to build EMA but an empty module_dict source is provided!")
    ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device)
    return ema_module


def get_loss(loss_name: str, cfg: omegaconf.DictConfig):
    """Instantiate loss from configuration."""
    klass = {
        'l1': torch.nn.L1Loss,
        'l2': torch.nn.MSELoss,
        'mel': losses.MelSpectrogramL1Loss,
        'mrstft': losses.MRSTFTLoss,
        'msspec': losses.MultiScaleMelSpectrogramLoss,
        'sisnr': losses.SISNR,
    }[loss_name]
    kwargs = dict(getattr(cfg, loss_name))
    return klass(**kwargs)


def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer:
    """Instantiate loss balancer from configuration for the provided weights."""
    kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg)
    return losses.Balancer(loss_weights, **kwargs)


def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module:
    """Initialize adversary from config."""
    klass = {
        'msd': adversarial.MultiScaleDiscriminator,
        'mpd': adversarial.MultiPeriodDiscriminator,
        'msstftd': adversarial.MultiScaleSTFTDiscriminator,
    }[name]
    adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name))
    return klass(**adv_cfg)


def get_adversarial_losses(cfg) -> nn.ModuleDict:
    """Initialize dict of adversarial losses from config."""
    device = cfg.device
    adv_cfg = getattr(cfg, 'adversarial')
    adversaries = adv_cfg.get('adversaries', [])
    adv_loss_name = adv_cfg['adv_loss']
    feat_loss_name = adv_cfg.get('feat_loss')
    normalize = adv_cfg.get('normalize', True)
    feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None
    if feat_loss_name:
        assert feat_loss_name in ['l1', 'l2'], f"Feature loss only supports L1 or L2 but {feat_loss_name} found."
        loss = get_loss(feat_loss_name, cfg)
        feat_loss = adversarial.FeatureMatchingLoss(loss, normalize)
    loss = adversarial.get_adv_criterion(adv_loss_name)
    loss_real = adversarial.get_real_criterion(adv_loss_name)
    loss_fake = adversarial.get_fake_criterion(adv_loss_name)
    adv_losses = nn.ModuleDict()
    for adv_name in adversaries:
        adversary = get_adversary(adv_name, cfg).to(device)
        optimizer = get_optimizer(adversary.parameters(), cfg.optim)
        adv_loss = adversarial.AdversarialLoss(
            adversary,
            optimizer,
            loss=loss,
            loss_real=loss_real,
            loss_fake=loss_fake,
            loss_feat=feat_loss,
            normalize=normalize
        )
        adv_losses[adv_name] = adv_loss
    return adv_losses


def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
    """Instantiate ViSQOL metric from config."""
    kwargs = dict_from_config(cfg)
    return metrics.ViSQOL(**kwargs)


def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric:
    """Instantiate Frechet Audio Distance metric from config."""
    kwargs = dict_from_config(cfg.tf)
    xp = dora.get_xp()
    kwargs['log_folder'] = xp.folder
    return metrics.FrechetAudioDistanceMetric(**kwargs)


def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
    """Instantiate KL-Divergence metric from config."""
    kld_metrics = {
        'passt': metrics.PasstKLDivergenceMetric,
    }
    klass = kld_metrics[cfg.model]
    kwargs = dict_from_config(cfg.get(cfg.model))
    return klass(**kwargs)


def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric:
    """Instantiate Text Consistency metric from config."""
    text_consistency_metrics = {
        'clap': metrics.CLAPTextConsistencyMetric
    }
    klass = text_consistency_metrics[cfg.model]
    kwargs = dict_from_config(cfg.get(cfg.model))
    return klass(**kwargs)


def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric:
    """Instantiate Chroma Cosine Similarity metric from config."""
    assert cfg.model == 'chroma_base', "Only the 'chroma_base' method is supported for the chroma cosine similarity metric"
    kwargs = dict_from_config(cfg.get(cfg.model))
    return metrics.ChromaCosineSimilarityMetric(**kwargs)


def get_audio_datasets(cfg: omegaconf.DictConfig,
                       dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]:
    """Build AudioDataset from configuration.

    Args:
        cfg (omegaconf.DictConfig): Configuration.
        dataset_type: The type of dataset to create.
    Returns:
        dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split.
    """
    dataloaders: dict = {}

    sample_rate = cfg.sample_rate
    channels = cfg.channels
    seed = cfg.seed
    max_sample_rate = cfg.datasource.max_sample_rate
    max_channels = cfg.datasource.max_channels

    assert cfg.dataset is not None, "Could not find dataset definition in config"

    dataset_cfg = dict_from_config(cfg.dataset)
    splits_cfg: dict = {}
    splits_cfg['train'] = dataset_cfg.pop('train')
    splits_cfg['valid'] = dataset_cfg.pop('valid')
    splits_cfg['evaluate'] = dataset_cfg.pop('evaluate')
    splits_cfg['generate'] = dataset_cfg.pop('generate')
    execute_only_stage = cfg.get('execute_only', None)

    for split, path in cfg.datasource.items():
        if not isinstance(path, str):
            continue  # skipping this as it is not a path
        if execute_only_stage is not None and split != execute_only_stage:
            continue
        logger.info(f"Loading audio data split {split}: {str(path)}")
        assert (
            cfg.sample_rate <= max_sample_rate
        ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found."
        assert (
            cfg.channels <= max_channels
        ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found."

        split_cfg = splits_cfg[split]
        split_kwargs = {k: v for k, v in split_cfg.items()}
        kwargs = {**dataset_cfg, **split_kwargs}  # split kwargs override the default dataset_cfg
        kwargs['sample_rate'] = sample_rate
        kwargs['channels'] = channels

        if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch:
            kwargs['num_samples'] = (
                flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch)

        num_samples = kwargs['num_samples']
        shuffle = kwargs['shuffle']

        return_info = kwargs.pop('return_info')
        batch_size = kwargs.pop('batch_size', None)
        num_workers = kwargs.pop('num_workers')

        if dataset_type == DatasetType.MUSIC:
            dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs)
        elif dataset_type == DatasetType.SOUND:
            dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs)
        elif dataset_type == DatasetType.AUDIO:
            dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs)
        else:
            raise ValueError(f"Dataset type is unsupported: {dataset_type}")

        loader = get_loader(
            dataset,
            num_samples,
            batch_size=batch_size,
            num_workers=num_workers,
            seed=seed,
            collate_fn=dataset.collater if return_info else None,
            shuffle=shuffle,
        )
        dataloaders[split] = loader

    return dataloaders
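Aside, not part of the commit: a hedged usage sketch of the two optim helpers above. The config keys mirror what the code reads (`optimizer`, `lr`, `adam`, `lr_scheduler` and the per-scheduler sub-dict); the concrete values are made up for illustration.

import omegaconf
import torch

from audiocraft.solvers import builders

model = torch.nn.Linear(16, 16)
optim_cfg = omegaconf.OmegaConf.create({
    'optimizer': 'adamw',                       # dispatches to torch.optim.AdamW
    'lr': 1e-4,
    'adam': {'betas': [0.9, 0.95], 'weight_decay': 0.01},
    'lr_scheduler': 'step',                     # dispatches to torch StepLR
    'step': {'step_size': 1000, 'gamma': 0.5},
})
# Passing the nn.Module triggers get_optim_parameter_groups under the hood.
optimizer = builders.get_optimizer(model, optim_cfg)
scheduler = builders.get_lr_scheduler(optimizer, optim_cfg, total_updates=10_000)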
audiocraft/solvers/compression.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import multiprocessing
from pathlib import Path
import typing as tp

import flashy
import omegaconf
import torch
from torch import nn

from . import base, builders
from .. import models, quantization
from ..utils import checkpoint
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_pool_executor


logger = logging.getLogger(__name__)


class CompressionSolver(base.StandardSolver):
    """Solver for compression task.

    The compression task combines a set of perceptual and objective losses
    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
    to perform high fidelity audio reconstruction.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.rng: torch.Generator  # set at each epoch
        self.adv_losses = builders.get_adversarial_losses(self.cfg)
        self.aux_losses = nn.ModuleDict()
        self.info_losses = nn.ModuleDict()
        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
        loss_weights = dict()
        for loss_name, weight in self.cfg.losses.items():
            if loss_name in ['adv', 'feat']:
                for adv_name, _ in self.adv_losses.items():
                    loss_weights[f'{loss_name}_{adv_name}'] = weight
            elif weight > 0:
                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
                loss_weights[loss_name] = weight
            else:
                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
        self.register_stateful('adv_losses')

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        # best model is the last for the compression model
        return None

    def build_model(self):
        """Instantiate model and optimizer."""
        # Model and optimizer
        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        """Show the compression model and employed adversarial loss."""
        self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
        self.log_model_summary(self.model)
        self.logger.info("Adversarial loss:")
        self.log_model_summary(self.adv_losses)
        self.logger.info("Auxiliary losses:")
        self.logger.info(self.aux_losses)
        self.logger.info("Info losses:")
        self.logger.info(self.info_losses)

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch."""
        x = batch.to(self.device)
        y = x.clone()

        qres = self.model(x)
        assert isinstance(qres, quantization.QuantizedResult)
        y_pred = qres.x
        # Log bandwidth in kb/s
        metrics['bandwidth'] = qres.bandwidth.mean()

        if self.is_training:
            d_losses: dict = {}
            if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
                for adv_name, adversary in self.adv_losses.items():
                    disc_loss = adversary.train_adv(y_pred, y)
                    d_losses[f'd_{adv_name}'] = disc_loss
                metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
            metrics.update(d_losses)

        balanced_losses: dict = {}
        other_losses: dict = {}

        # penalty from quantization
        if qres.penalty is not None and qres.penalty.requires_grad:
            other_losses['penalty'] = qres.penalty  # penalty term from the quantizer

        # adversarial losses
        for adv_name, adversary in self.adv_losses.items():
            adv_loss, feat_loss = adversary(y_pred, y)
            balanced_losses[f'adv_{adv_name}'] = adv_loss
            balanced_losses[f'feat_{adv_name}'] = feat_loss

        # auxiliary losses
        for loss_name, criterion in self.aux_losses.items():
            loss = criterion(y_pred, y)
            balanced_losses[loss_name] = loss

        # weighted losses
        metrics.update(balanced_losses)
        metrics.update(other_losses)
        metrics.update(qres.metrics)

        if self.is_training:
            # backprop losses that are not handled by balancer
            other_loss = torch.tensor(0., device=self.device)
            if 'penalty' in other_losses:
                other_loss += other_losses['penalty']
            if other_loss.requires_grad:
                other_loss.backward(retain_graph=True)
                ratio1 = sum(p.grad.data.norm(p=2).pow(2)
                             for p in self.model.parameters() if p.grad is not None)
                assert isinstance(ratio1, torch.Tensor)
                metrics['ratio1'] = ratio1.sqrt()

            # balancer losses backward, returns effective training loss
            # with effective weights at the current batch.
            metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
            # add metrics corresponding to weight ratios
            metrics.update(self.balancer.metrics)
            ratio2 = sum(p.grad.data.norm(p=2).pow(2)
                         for p in self.model.parameters() if p.grad is not None)
            assert isinstance(ratio2, torch.Tensor)
            metrics['ratio2'] = ratio2.sqrt()

            # optim
            flashy.distrib.sync_model(self.model)
            if self.cfg.optim.max_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

        # informative losses only
        info_losses: dict = {}
        with torch.no_grad():
            for loss_name, criterion in self.info_losses.items():
                loss = criterion(y_pred, y)
                info_losses[loss_name] = loss

        metrics.update(info_losses)

        # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
        adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')]
        if len(adv_losses) > 0:
            metrics['adv'] = torch.sum(torch.stack(adv_losses))
        feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')]
        if len(feat_losses) > 0:
            metrics['feat'] = torch.sum(torch.stack(feat_losses))

        return metrics

    def run_epoch(self):
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage. Runs audio reconstruction evaluation."""
        self.model.eval()
        evaluate_stage_name = str(self.current_stage)

        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
        average = flashy.averager()

        pendings = []
        ctx = multiprocessing.get_context('spawn')
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
            for idx, batch in enumerate(lp):
                x = batch.to(self.device)
                with torch.no_grad():
                    qres = self.model(x)

                y_pred = qres.x.cpu()
                y = batch.cpu()  # should already be on CPU but just in case
                pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))

            metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
            for pending in metrics_lp:
                metrics = pending.result()
                metrics = average(metrics)

        metrics = flashy.distrib.average_metrics(metrics, len(loader))
        return metrics

    def generate(self):
        """Generate stage."""
        self.model.eval()
        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
        generate_stage_name = str(self.current_stage)

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            with torch.no_grad():
                qres = self.model(reference)
            assert isinstance(qres, quantization.QuantizedResult)

            reference = reference.cpu()
            estimate = qres.x.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)

        flashy.distrib.barrier()

    def load_from_pretrained(self, name: str) -> dict:
        model = models.CompressionModel.get_pretrained(name)
        if isinstance(model, models.DAC):
            raise RuntimeError("Cannot fine tune a DAC model.")
        elif isinstance(model, models.HFEncodecCompressionModel):
            self.logger.warning('Trying to automatically convert a HuggingFace model '
                                'to AudioCraft, this might fail!')
            state = model.model.state_dict()
            new_state = {}
            # Remap HuggingFace parameter names to the AudioCraft naming scheme.
            for k, v in state.items():
                if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
                    # We need to determine if this is a convtr or a regular conv.
                    layer = int(k.split('.')[2])
                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
                        k = k.replace('.conv.', '.convtr.')
                k = k.replace('encoder.layers.', 'encoder.model.')
                k = k.replace('decoder.layers.', 'decoder.model.')
                k = k.replace('conv.', 'conv.conv.')
                k = k.replace('convtr.', 'convtr.convtr.')
                k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
                k = k.replace('.codebook.', '._codebook.')
                new_state[k] = v
            state = new_state
        elif isinstance(model, models.EncodecModel):
            state = model.state_dict()
        else:
            raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
        return {
            'best_state': {'model': state}
        }

    @staticmethod
    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a CompressionModel from a given checkpoint path or dora sig.
        This method is a convenient endpoint to load a CompressionModel to use in other solvers.

        Args:
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
                This also supports pre-trained models by using a path of the form //pretrained/NAME.
                See `model_from_pretrained` for a list of supported pretrained models.
            device (torch.device or str): Device on which the model is loaded.
        """
        checkpoint_path = str(checkpoint_path)
        if checkpoint_path.startswith('//pretrained/'):
            name = checkpoint_path.split('/', 3)[-1]
            return models.CompressionModel.get_pretrained(name, device)
        logger = logging.getLogger(__name__)
        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
        _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
        assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
        state = checkpoint.load_checkpoint(_checkpoint_path)
        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
        cfg = state['xp.cfg']
        cfg.device = device
        compression_model = models.builders.get_compression_model(cfg).to(device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"

        assert 'best_state' in state and state['best_state'] != {}
        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
        compression_model.load_state_dict(state['best_state']['model'])
        compression_model.eval()
        logger.info("Compression model loaded!")
        return compression_model

    @staticmethod
    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
                                      checkpoint_path: tp.Union[Path, str],
                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

        Args:
            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
            device (torch.device or str): Device on which the model is loaded.
        """
        compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
        compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
        return compression_model


def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
    """Audio reconstruction evaluation method that can be conveniently pickled."""
    metrics = {}
    if cfg.evaluate.metrics.visqol:
        visqol = builders.get_visqol(cfg.metrics.visqol)
        metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate)
    sisnr = builders.get_loss('sisnr', cfg)
    metrics['sisnr'] = sisnr(y_pred, y)
    return metrics

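Editor's note: to make the checkpoint-resolution logic of `model_from_checkpoint` above concrete, here is a minimal usage sketch. It assumes `audiocraft` is installed with network access to fetch weights; the `facebook/encodec_32khz` name is an illustrative pretrained model, not part of this upload.

import torch
from audiocraft.solvers.compression import CompressionSolver

# A '//pretrained/' prefix short-circuits resolution and calls
# models.CompressionModel.get_pretrained(...) instead of loading a dora checkpoint.
codec = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_32khz', device='cpu')

wav = torch.randn(1, codec.channels, codec.sample_rate)  # one second of noise
codes, scale = codec.encode(wav)                         # discrete codes of shape [B, K, T_s]
reconstruction = codec.decode(codes, scale)              # back to a waveform
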
audiocraft/solvers/diffusion.py
ADDED
@@ -0,0 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import flashy
import julius
import omegaconf
import torch
import torch.nn.functional as F

from . import builders
from . import base
from .. import models
from ..modules.diffusion_schedule import NoiseSchedule
from ..metrics import RelativeVolumeMel
from ..models.builders import get_processor
from ..utils.samples.manager import SampleManager
from ..solvers.compression import CompressionSolver


class PerStageMetrics:
    """Handle computing the metrics per stage.
    It outputs the metrics per range of diffusion states.
    e.g. avg loss when t in [250, 500]
    """
    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
        if type(step) is int:
            stage = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{stage}": loss for name, loss in losses.items()}
        elif type(step) is torch.Tensor:
            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
            out: tp.Dict[str, float] = {}
            for stage_idx in range(self.num_stages):
                mask = (stage_tensor == stage_idx)
                N = mask.sum()
                stage_out = {}
                if N > 0:  # pass if no elements in the stage
                    for name, loss in losses.items():
                        stage_loss = (mask * loss).sum() / N
                        stage_out[f"{name}_{stage_idx}"] = stage_loss
                out = {**out, **stage_out}
            return out


class DataProcess:
    """Apply filtering or resampling.

    Args:
        initial_sr (int): Sample rate of the dataset.
        target_sr (int): Sample rate after resampling.
        use_resampling (bool): Whether or not to perform resampling.
        use_filter (bool): When True, filter the data to keep only one frequency band.
        n_bands (int): Number of bands used.
        idx_band (int): Index of the frequency band. 0 are lows ... (n_bands - 1) highs.
        device (torch.device or str): Device on which to run the band filter.
        cutoffs (None or list): The cutoff frequencies of the band filtering.
            If None then mel scale bands are used.
        boost (bool): Make the data scale match our music dataset.
    """
    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
                 use_filter: bool = False, n_bands: int = 4,
                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
        assert idx_band < n_bands
        self.idx_band = idx_band
        if use_filter:
            if cutoffs is not None:
                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
            else:
                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        if x is None:
            return None
        if self.boost:
            x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            x *= 0.22  # fixed: the uploaded file had the no-op statement `x * 0.22`, dropping the intended scaling
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Upsampling only."""
        if self.use_resampling:
            # fixed: the uploaded file resampled target_sr -> target_sr (a no-op), which
            # contradicts the docstring; upsample back to the initial rate instead.
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x


class DiffusionSolver(base.StandardSolver):
    """Solver for diffusion task.

    The diffusion task allows for MultiBand diffusion model training.

    Args:
        cfg (DictConfig): Configuration.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = cfg.device
        self.sample_rate: int = self.cfg.sample_rate
        self.codec_model = CompressionSolver.model_from_checkpoint(
            cfg.compression_model_checkpoint, device=self.device)

        self.codec_model.set_num_codebooks(cfg.n_q)
        assert self.codec_model.sample_rate == self.cfg.sample_rate, (
            f"Codec model sample rate is {self.codec_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )

        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
        self.register_stateful('sample_processor')
        self.sample_processor.to(self.device)

        self.schedule = NoiseSchedule(
            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)

        self.eval_metric: tp.Optional[torch.nn.Module] = None

        self.rvm = RelativeVolumeMel()
        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
                                          idx_band=cfg.filter.idx_band, device=self.device)

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        if self._current_stage == "evaluate":
            return 'rvm'
        else:
            return 'loss'

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.codec_model.decode_latent(codes)
        return emb

    def build_model(self):
        """Build model and optimizer as well as optional Exponential Moving Average of the model.
        """
        # Model and optimizer
        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Build audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        # TODO
        raise NotImplementedError()

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch."""
        x = batch.to(self.device)
        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss

        condition = self.get_condition(x)  # [bs, 128, T/hop, n_emb]
        sample = self.data_processor.process_data(x)

        input_, target, step = self.schedule.get_training_item(sample,
                                                               tensor_step=self.cfg.schedule.variable_step_batch)
        out = self.model(input_, step, condition=condition).sample

        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
        loss = base_loss / reference_loss ** self.cfg.loss.norm_power

        if self.is_training:
            loss.mean().backward()
            flashy.distrib.sync_model(self.model)
            self.optimizer.step()
            self.optimizer.zero_grad()
        metrics = {
            'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
        }
        metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
        metrics.update({
            'std_in': input_.std(), 'std_out': out.std()})
        return metrics

    def run_epoch(self):
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage.
        Runs audio reconstruction evaluation.
        """
        self.model.eval()
        evaluate_stage_name = f'{self.current_stage}'
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)

        metrics = {}
        n = 1
        for idx, batch in enumerate(lp):
            x = batch.to(self.device)
            with torch.no_grad():
                y_pred = self.regenerate(x)

            y_pred = y_pred.cpu()
            y = batch.cpu()  # should already be on CPU but just in case
            rvm = self.rvm(y_pred, y)
            lp.update(**rvm)
            if len(metrics) == 0:
                metrics = rvm
            else:
                for key in rvm.keys():
                    metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
                n += 1  # fixed: the uploaded file never advanced n, so this was not a true running mean
        metrics = flashy.distrib.average_metrics(metrics)
        return metrics

    @torch.no_grad()
    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
        """Regenerate the given waveform."""
        condition = self.get_condition(wav)
        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))  # sampling rate changes.
        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
                                                   step_list=step_list)
        result = self.data_processor.inverse_process(result)
        return result

    def generate(self):
        """Generate stage."""
        sample_manager = SampleManager(self.xp)
        self.model.eval()
        generate_stage_name = f'{self.current_stage}'

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            estimate = self.regenerate(reference)
            reference = reference.cpu()
            estimate = estimate.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
        flashy.distrib.barrier()

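Editor's note: the stage bucketing performed by `PerStageMetrics` above is easy to sanity-check in isolation. A standalone sketch with made-up losses and steps (pure torch, no audiocraft imports needed):

import torch

num_steps, num_stages = 1000, 4
losses = {'loss': torch.tensor([0.9, 0.5, 0.2, 0.1])}
step = torch.tensor([10, 260, 510, 760])  # one diffusion step per sample

stage = ((step / num_steps) * num_stages).long()  # tensor([0, 1, 2, 3])
out = {}
for s in range(num_stages):
    mask = stage == s
    if mask.sum() > 0:  # skip empty stages, as in PerStageMetrics.__call__
        out[f'loss_{s}'] = (mask * losses['loss']).sum() / mask.sum()
print(out)  # {'loss_0': tensor(0.9000), ..., 'loss_3': tensor(0.1000)}
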
audiocraft/solvers/magnet.py
ADDED
@@ -0,0 +1,276 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from omegaconf import DictConfig
from . import builders, musicgen
from einops import rearrange
from torch.nn import functional as F
from ..modules.conditioners import SegmentWithAttributes

import torch
import numpy as np
import random
import typing as tp
import math
import flashy


class MagnetSolver(musicgen.MusicGenSolver):
    """Solver for MAGNeT - Masked Audio Generation using
    a single Non-autoregressive Transformer https://arxiv.org/abs/2401.04577.
    """
    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)

        # initialize generation parameters by config
        self.generation_params = {
            'use_sampling': self.cfg.generate.lm.use_sampling,
            'temp': self.cfg.generate.lm.temp,
            'top_k': self.cfg.generate.lm.top_k,
            'top_p': self.cfg.generate.lm.top_p,
            'max_cfg_coef': self.cfg.generate.lm.max_cfg_coef,
            'min_cfg_coef': self.cfg.generate.lm.min_cfg_coef,
            'decoding_steps': list(self.cfg.generate.lm.decoding_steps),
            'anneal_temp': self.cfg.generate.lm.anneal_temp,
            'span_scoring': self.cfg.generate.lm.span_scoring,
            'span_arrangement': self.cfg.generate.lm.span_arrangement
        }

        sequence_len = int(cfg.dataset.segment_duration * self.compression_model.frame_rate)
        self.mean_maskrate_to_u = torch.tensor(self._calc_mean_maskrate_to_u_LUT(sequence_len), device=self.device)
        self.ce_per_codebook = [torch.log(torch.tensor(self.compression_model.cardinality, device=self.device))
                                for _ in range(cfg.transformer_lm.n_q)]

    def build_model(self) -> None:
        self.cfg.transformer_lm.segment_duration = self.cfg.dataset.segment_duration
        self.cfg.transformer_lm.span_len = self.cfg.masking.span_len
        assert self.cfg.efficient_attention_backend == "xformers", "MAGNeT v1 models support only xformers backend."
        super().build_model()

    def _calc_mean_maskrate_to_u_LUT(self, T: int):
        """ Create a Look Up Table (LUT) transforming a discrete masking percentage m in 0,1,...,100 to u,
        the number of overlapping spans of length L to place s.t. the masking rate is approximately m/float(100).
        It first creates the inverse transformation, of the masking rate as function of u,
        using the expression choose(T - L, u) / choose(T, u), where L is the atomic span length used
        during masking. See https://arxiv.org/abs/2401.04577,
        appendix C, for the mean mask rate derivation.

        We leverage the fact that:
            choose(T - L, u) / choose(T, u) = Prod_{j = 0}^{u - 1}((T - L - j)/(T - j))
        in the provided implementation, in order to avoid overflow.
        Args:
            T (int): Sequence length.
        Returns:
            (List) A LUT transforming m in 0,1,...,100 to u,
            s.t. the masking rate of the span-L mask is approximately m/float(100).
        """
        L = self.cfg.masking.span_len

        u2mean = [0.0]  # mean mask rate is 0.0 for u = 0
        v = (T - L) / float(T)
        for u in range(1, T):
            u2mean.append(1 - v)
            v *= (T - L - u) / (T - u)  # Overflow-safe implementation of choose(T - L, u) / choose(T, u).

        mean2u = []
        for maskperc in range(101):
            maskrate = maskperc / float(100)
            u = int(np.searchsorted(u2mean, maskrate))
            mean2u.append(u)

        return mean2u

    def _non_spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor:
        """ Construct a boolean mask of shape [B, T], with masking rates defined by mask_probs.
        The masked tokens are singletons, placed uniformly at random.
        Args:
            mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,]
            B (int): Batch size.
            T (int): Sequence length.
            device (torch.device): device of the output tensor
        Returns:
            (torch.Tensor): A mask of shape [B, T]
        """
        num_token_masked = (T * mask_probs).round().clamp(min=1)
        batch_randperm = torch.rand((B, T), device=device).argsort(dim=-1)
        return batch_randperm < rearrange(num_token_masked, 'b -> b 1')

    def _spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor:
        """ Construct a spans mask with masking rates defined by mask_probs,
        where the atomic span length ( > 1 ) is defined by cfg.masking.span_len.
        Args:
            mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,]
            B (int): Batch size.
            T (int): Sequence length.
            device (torch.device): device of the output tensor
        Returns:
            (torch.Tensor): A spans mask of shape [B, T]
        """
        rounded_probs = torch.round(100 * mask_probs).long()
        k = self.mean_maskrate_to_u[rounded_probs].clamp(min=1)  # k is the number of span starts

        # sample random span starts
        batch_randperm = torch.rand((B, T), device=device).argsort(dim=-1)
        mask = batch_randperm < rearrange(k, 'b -> b 1')
        B, T = mask.shape
        shifted_mask = mask.clone()
        for _ in range(self.cfg.masking.span_len - 1):
            shifted_mask = torch.concat((torch.full((B, 1), False, device=device), shifted_mask[:, :-1]), dim=1)
            mask = torch.logical_or(mask, shifted_mask)

        return mask

    def _get_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor:
        """ Construct a boolean mask with masking rates defined by mask_probs, and atomic
        span length defined by cfg.masking.span_len.
        Args:
            mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,]
            B (int): Batch size.
            T (int): Sequence length.
            device (torch.device): device of the output tensor
        Returns:
            (torch.Tensor): A boolean tensor of shape [B, T]
        """
        if self.cfg.masking.span_len <= 1:
            return self._non_spans_mask(mask_probs, B, T, device)

        return self._spans_mask(mask_probs, B, T, device)

    def _compute_cross_entropy_magnet(self, logits: torch.Tensor,
                                      targets: torch.Tensor, mask: torch.Tensor, stage: torch.Tensor) -> torch.Tensor:
        """ Compute cross entropy between multi-codebook targets and model's logits.
        The cross entropy is computed only on a specific codebook, defined by the stage argument.
        Valid timesteps for each codebook are pulled from the mask, where invalid
        timesteps are set to 0.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
            stage (torch.Tensor): The codebook (idx) that is being optimized, as a scalar tensor.
        Returns:
            ce (torch.Tensor): Cross entropy of the codebook that is being optimized.
        """
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        logits_k = logits[:, stage, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
        targets_k = targets[:, stage, ...].contiguous().view(-1)  # [B x T]
        mask_k = mask[:, stage, ...].contiguous().view(-1)  # [B x T]

        IGNORE_IDX = -1
        targets_k[~mask_k] = IGNORE_IDX
        q_ce = F.cross_entropy(logits_k, targets_k, ignore_index=IGNORE_IDX)

        ce += q_ce
        return ce

    def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
        """Perform one training or valid step on a given batch."""
        check_synchronization_points = idx == 1 and self.device == 'cuda'

        condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points)

        self.deadlock_detect.update('tokens_and_conditions')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('warn')

        B, K, T = audio_tokens.shape
        device = self.device

        # Choose the stage (codebook idx) for update, uniformly at random.
        stage_ = random.randint(0, K - 1)
        stage = torch.full((1,), stage_, device=device)

        # masking
        rand_time = torch.zeros((B,), device=device).float().uniform_(0, 1)
        rand_mask_probs = torch.cos(rand_time * math.pi * 0.5)

        # stage mask
        stage_mask = self._get_mask(rand_mask_probs, B, T, device)  # [B, T]
        stage_mask = stage_mask.unsqueeze(1)  # [B, 1, T]

        # Keep all preceding codebooks.
        mask = torch.full((B, K, T), False, device=device)
        mask[:, stage, :] = stage_mask

        # Mask all codebooks larger than stage_
        mask_id = self.model.special_token_id
        mask[:, (stage_+1):, :] = torch.full((B, K - stage_ - 1, T), True, device=device)
        input_tokens = torch.where(mask, mask_id, audio_tokens)

        # Take loss only on the chosen stage, and only on the masked tokens.
        loss_mask = torch.full((B, K, T), False, device=device)
        loss_mask[:, stage, :] = stage_mask

        with self.autocast:
            model_output = self.model.compute_predictions(input_tokens, [], condition_tensors, stage=stage_)
            logits = model_output.logits
            loss_mask &= padding_mask
            ce = self._compute_cross_entropy_magnet(logits, audio_tokens, loss_mask, stage)
            loss = ce
        self.deadlock_detect.update('loss')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('default')

        if self.is_training:
            metrics['lr'] = self.optimizer.param_groups[0]['lr']
            if self.scaler is not None:
                loss = self.scaler.scale(loss)
            self.deadlock_detect.update('scale')
            if self.cfg.fsdp.use:
                loss.backward()
                flashy.distrib.average_tensors(self.model.buffers())
            elif self.cfg.optim.eager_sync:
                with flashy.distrib.eager_sync_model(self.model):
                    loss.backward()
            else:
                # this should always be slower but can be useful
                # for weird use cases like multiple backwards.
                loss.backward()
                flashy.distrib.sync_model(self.model)
            self.deadlock_detect.update('backward')

            if self.scaler is not None:
                self.scaler.unscale_(self.optimizer)
            if self.cfg.optim.max_norm:
                if self.cfg.fsdp.use:
                    metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                else:
                    metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.optim.max_norm
                    )
            if self.scaler is None:
                self.optimizer.step()
            else:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.deadlock_detect.update('optim')
            if self.scaler is not None:
                scale = self.scaler.get_scale()
                metrics['grad_scale'] = scale
            if not loss.isfinite().all():
                raise RuntimeError("Model probably diverged.")

        metrics['ce'] = ce
        metrics['ppl'] = torch.exp(ce)

        return metrics


class AudioMagnetSolver(MagnetSolver):
    """Solver for audio-MAGNeT. A MAGNeT model for sound generation.

    More information can be found in the MAGNeT model card.
    """
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND

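Editor's note: the overflow-safe recurrence in `_calc_mean_maskrate_to_u_LUT` can be checked against the closed form choose(T - L, u) / choose(T, u) for small T, where exact binomials are cheap to compute. A standalone sketch (values are toy choices):

import math

T, L = 50, 3  # toy sequence length and atomic span length
v = (T - L) / float(T)  # same initialization as the solver
for u in range(1, 10):
    exact = math.comb(T - L, u) / math.comb(T, u)  # probability a fixed position stays unmasked
    assert abs(v - exact) < 1e-12, (u, v, exact)
    v *= (T - L - u) / (T - u)  # the solver's overflow-safe update; mean mask rate is 1 - v
print("recurrence matches choose(T - L, u) / choose(T, u)")
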
audiocraft/solvers/musicgen.py
ADDED
@@ -0,0 +1,721 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
import time
|
9 |
+
import typing as tp
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import flashy
|
13 |
+
import math
|
14 |
+
import omegaconf
|
15 |
+
import torch
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
from . import base, builders
|
19 |
+
from .compression import CompressionSolver
|
20 |
+
from .. import metrics as eval_metrics
|
21 |
+
from .. import models
|
22 |
+
from ..data.audio_dataset import AudioDataset
|
23 |
+
from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo
|
24 |
+
from ..data.audio_utils import normalize_audio
|
25 |
+
from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
|
26 |
+
from ..utils.cache import CachedBatchWriter, CachedBatchLoader
|
27 |
+
from ..utils.samples.manager import SampleManager
|
28 |
+
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once, model_hash
|
29 |
+
|
30 |
+
|
31 |
+
class MusicGenSolver(base.StandardSolver):
|
32 |
+
"""Solver for MusicGen training task.
|
33 |
+
|
34 |
+
Used in: https://arxiv.org/abs/2306.05284
|
35 |
+
"""
|
36 |
+
DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC
|
37 |
+
|
38 |
+
def __init__(self, cfg: omegaconf.DictConfig):
|
39 |
+
super().__init__(cfg)
|
40 |
+
# easier access to sampling parameters
|
41 |
+
self.generation_params = {
|
42 |
+
'use_sampling': self.cfg.generate.lm.use_sampling,
|
43 |
+
'temp': self.cfg.generate.lm.temp,
|
44 |
+
'top_k': self.cfg.generate.lm.top_k,
|
45 |
+
'top_p': self.cfg.generate.lm.top_p,
|
46 |
+
}
|
47 |
+
self._best_metric_name: tp.Optional[str] = 'ce'
|
48 |
+
|
49 |
+
self._cached_batch_writer = None
|
50 |
+
self._cached_batch_loader = None
|
51 |
+
if cfg.cache.path:
|
52 |
+
if cfg.cache.write:
|
53 |
+
self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
|
54 |
+
if self.cfg.cache.write_num_shards:
|
55 |
+
self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
|
56 |
+
self._best_metric_name = None
|
57 |
+
else:
|
58 |
+
self._cached_batch_loader = CachedBatchLoader(
|
59 |
+
Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
|
60 |
+
min_length=self.cfg.optim.updates_per_epoch or 1)
|
61 |
+
self.dataloaders['original_train'] = self.dataloaders['train']
|
62 |
+
self.dataloaders['train'] = self._cached_batch_loader # type: ignore
|
63 |
+
|
64 |
+
@staticmethod
|
65 |
+
def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
|
66 |
+
device: tp.Optional[str] = None, autocast: bool = True,
|
67 |
+
batch_size: tp.Optional[int] = None,
|
68 |
+
override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
|
69 |
+
**kwargs):
|
70 |
+
"""Mostly a convenience function around magma.train.get_solver_from_sig,
|
71 |
+
populating all the proper param, deactivating EMA, FSDP, loading the best state,
|
72 |
+
basically all you need to get a solver ready to "play" with in single GPU mode
|
73 |
+
and with minimal memory overhead.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
sig (str): signature to load.
|
77 |
+
dtype (str or None): potential dtype, as a string, i.e. 'float16'.
|
78 |
+
device (str or None): potential device, as a string, i.e. 'cuda'.
|
79 |
+
override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
|
80 |
+
"""
|
81 |
+
from audiocraft import train
|
82 |
+
our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
|
83 |
+
our_override_cfg['autocast'] = autocast
|
84 |
+
if dtype is not None:
|
85 |
+
our_override_cfg['dtype'] = dtype
|
86 |
+
if device is not None:
|
87 |
+
our_override_cfg['device'] = device
|
88 |
+
if batch_size is not None:
|
89 |
+
our_override_cfg['dataset'] = {'batch_size': batch_size}
|
90 |
+
if override_cfg is None:
|
91 |
+
override_cfg = {}
|
92 |
+
override_cfg = omegaconf.OmegaConf.merge(
|
93 |
+
omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore
|
94 |
+
solver = train.get_solver_from_sig(
|
95 |
+
sig, override_cfg=override_cfg,
|
96 |
+
load_best=True, disable_fsdp=True,
|
97 |
+
ignore_state_keys=['optimizer', 'ema'], **kwargs)
|
98 |
+
solver.model.eval()
|
99 |
+
return solver
|
100 |
+
|
101 |
+
def get_formatter(self, stage_name: str) -> flashy.Formatter:
|
102 |
+
return flashy.Formatter({
|
103 |
+
'lr': '.2E',
|
104 |
+
'ce': '.3f',
|
105 |
+
'ppl': '.3f',
|
106 |
+
'grad_norm': '.3E',
|
107 |
+
}, exclude_keys=['ce_q*', 'ppl_q*'])
|
108 |
+
|
109 |
+
@property
|
110 |
+
def best_metric_name(self) -> tp.Optional[str]:
|
111 |
+
return self._best_metric_name
|
112 |
+
|
113 |
+
def build_model(self) -> None:
|
114 |
+
"""Instantiate models and optimizer."""
|
115 |
+
# we can potentially not use all quantizers with which the EnCodec model was trained
|
116 |
+
# (e.g. we trained the model with quantizers dropout)
|
117 |
+
self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
|
118 |
+
self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
|
119 |
+
assert self.compression_model.sample_rate == self.cfg.sample_rate, (
|
120 |
+
f"Compression model sample rate is {self.compression_model.sample_rate} but "
|
121 |
+
f"Solver sample rate is {self.cfg.sample_rate}."
|
122 |
+
)
|
123 |
+
# ensure we have matching configuration between LM and compression model
|
124 |
+
assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
|
125 |
+
"Cardinalities of the LM and compression model don't match: ",
|
126 |
+
f"LM cardinality is {self.cfg.transformer_lm.card} vs ",
|
127 |
+
f"compression model cardinality is {self.compression_model.cardinality}"
|
128 |
+
)
|
129 |
+
assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
|
130 |
+
"Numbers of codebooks of the LM and compression models don't match: ",
|
131 |
+
f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ",
|
132 |
+
f"compression model numer of codebooks is {self.compression_model.num_codebooks}"
|
133 |
+
)
|
134 |
+
self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
|
135 |
+
self.compression_model.num_codebooks, self.compression_model.cardinality,
|
136 |
+
self.compression_model.frame_rate)
|
137 |
+
# instantiate LM model
|
138 |
+
self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
|
139 |
+
if self.cfg.fsdp.use:
|
140 |
+
assert not self.cfg.autocast, "Cannot use autocast with fsdp"
|
141 |
+
self.model = self.wrap_with_fsdp(self.model)
|
142 |
+
self.register_ema('model')
|
143 |
+
# initialize optimization
|
144 |
+
self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
|
145 |
+
self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
|
146 |
+
self.register_stateful('model', 'optimizer', 'lr_scheduler')
|
147 |
+
self.register_best_state('model')
|
148 |
+
self.autocast_dtype = {
|
149 |
+
'float16': torch.float16, 'bfloat16': torch.bfloat16
|
150 |
+
}[self.cfg.autocast_dtype]
|
151 |
+
self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
|
152 |
+
if self.cfg.fsdp.use:
|
153 |
+
need_scaler = self.cfg.fsdp.param_dtype == 'float16'
|
154 |
+
else:
|
155 |
+
need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
|
156 |
+
if need_scaler:
|
157 |
+
if self.cfg.fsdp.use:
|
158 |
+
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
159 |
+
self.scaler = ShardedGradScaler() # type: ignore
|
160 |
+
else:
|
161 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
162 |
+
self.register_stateful('scaler')
|
163 |
+
|
164 |
+
def build_dataloaders(self) -> None:
|
165 |
+
"""Instantiate audio dataloaders for each stage."""
|
166 |
+
self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)
|
167 |
+
|
168 |
+
def show(self) -> None:
|
169 |
+
"""Show the compression model and LM model."""
|
170 |
+
self.logger.info("Compression model:")
|
171 |
+
self.log_model_summary(self.compression_model)
|
172 |
+
self.logger.info("LM model:")
|
173 |
+
self.log_model_summary(self.model)
|
174 |
+
|
175 |
+
def load_state_dict(self, state: dict) -> None:
|
176 |
+
if 'condition_provider' in state:
|
177 |
+
model_state = state['model']
|
178 |
+
condition_provider_state = state.pop('condition_provider')
|
179 |
+
prefix = 'condition_provider.'
|
180 |
+
for key, value in condition_provider_state.items():
|
181 |
+
key = prefix + key
|
182 |
+
assert key not in model_state
|
183 |
+
model_state[key] = value
|
184 |
+
if 'compression_model' in state:
|
185 |
+
# We used to store the `compression_model` state in the checkpoint, however
|
186 |
+
# this is in general not needed, as the compression model should always be readable
|
187 |
+
# from the original `cfg.compression_model_checkpoint` location.
|
188 |
+
compression_model_state = state.pop('compression_model')
|
189 |
+
before_hash = model_hash(self.compression_model)
|
190 |
+
self.compression_model.load_state_dict(compression_model_state)
|
191 |
+
after_hash = model_hash(self.compression_model)
|
192 |
+
if before_hash != after_hash:
|
193 |
+
raise RuntimeError(
|
194 |
+
"The compression model state inside the checkpoint is different"
|
195 |
+
" from the one obtained from compression_model_checkpoint..."
|
196 |
+
"We do not support altering the compression model inside the LM "
|
197 |
+
"checkpoint as parts of the code, in particular for running eval post-training "
|
198 |
+
"will use the compression_model_checkpoint as the source of truth.")
|
199 |
+
|
200 |
+
super().load_state_dict(state)
|
201 |
+
|
202 |
+
def load_from_pretrained(self, name: str):
|
203 |
+
# TODO: support native HF versions of MusicGen.
|
204 |
+
lm_pkg = models.loaders.load_lm_model_ckpt(name)
|
205 |
+
state: dict = {
|
206 |
+
'best_state': {
|
207 |
+
'model': lm_pkg['best_state'],
|
208 |
+
},
|
209 |
+
}
|
210 |
+
return state
|
211 |
+
|
212 |
+
def _compute_cross_entropy(
|
213 |
+
self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
|
214 |
+
) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
|
215 |
+
"""Compute cross entropy between multi-codebook targets and model's logits.
|
216 |
+
The cross entropy is computed per codebook to provide codebook-level cross entropy.
|
217 |
+
Valid timesteps for each of the codebook are pulled from the mask, where invalid
|
218 |
+
timesteps are set to 0.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
logits (torch.Tensor): Model's logits of shape [B, K, T, card].
|
222 |
+
targets (torch.Tensor): Target codes, of shape [B, K, T].
|
223 |
+
mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
|
224 |
+
Returns:
|
225 |
+
ce (torch.Tensor): Cross entropy averaged over the codebooks
|
226 |
+
ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
|
227 |
+
"""
|
228 |
+
B, K, T = targets.shape
|
229 |
+
assert logits.shape[:-1] == targets.shape
|
230 |
+
assert mask.shape == targets.shape
|
231 |
+
ce = torch.zeros([], device=targets.device)
|
232 |
+
ce_per_codebook: tp.List[torch.Tensor] = []
|
233 |
+
for k in range(K):
|
234 |
+
logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
|
235 |
+
targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
|
236 |
+
mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
|
237 |
+
ce_targets = targets_k[mask_k]
|
238 |
+
ce_logits = logits_k[mask_k]
|
239 |
+
q_ce = F.cross_entropy(ce_logits, ce_targets)
|
240 |
+
ce += q_ce
|
241 |
+
ce_per_codebook.append(q_ce.detach())
|
242 |
+
# average cross entropy across codebooks
|
243 |
+
ce = ce / K
|
244 |
+
return ce, ce_per_codebook
|
245 |
+
|
    def _prepare_tokens_and_attributes(
            self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
            check_synchronization_points: bool = False
    ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
        """Prepare input batches for language model training.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
                and corresponding metadata as SegmentWithAttributes (with B items).
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            Condition tensors (dict[str, any]): Preprocessed condition attributes.
            Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
                with B the batch size, K the number of codebooks, T_s the token timesteps.
            Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
        """
        if self.model.training:
            warnings.warn(
                "Up to version 1.0.1, the _prepare_tokens_and_attributes method was evaluated with "
                "`torch.no_grad()`. This is inconsistent with how models were trained in the MusicGen "
                "paper. We removed the `torch.no_grad()` in version 1.1.0. Small changes to the final "
                "performance are expected. Really sorry about that.")
        if self._cached_batch_loader is None or self.current_stage != "train":
            audio, infos = batch
            audio = audio.to(self.device)
            audio_tokens = None
            assert audio.size(0) == len(infos), (
                f"Mismatch between number of items in audio batch ({audio.size(0)})"
                f" and in metadata ({len(infos)})"
            )
        else:
            audio = None
            # In that case the batch will be a tuple coming from the _cached_batch_writer bit below.
            infos, = batch  # type: ignore
            assert all([isinstance(info, AudioInfo) for info in infos])
            assert all([info.audio_tokens is not None for info in infos])  # type: ignore
            audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device)  # type: ignore
            audio_tokens = audio_tokens.long()
            for info in infos:
                if isinstance(info, MusicInfo):
                    # Careful here, if you want to use this condition_wav (e.g. chroma conditioning),
                    # then you must be using the chroma cache! Otherwise the code will try
                    # to use this segment and fail (by that I mean you will see NaN everywhere).
                    info.self_wav = WavCondition(
                        torch.full([1, info.channels, info.total_frames], float('NaN')),
                        length=torch.tensor([info.n_frames]),
                        sample_rate=[info.sample_rate],
                        path=[info.meta.path],
                        seek_time=[info.seek_time])
                    dataset = get_dataset_from_loader(self.dataloaders['original_train'])
                    assert isinstance(dataset, MusicDataset), type(dataset)
                    if dataset.paraphraser is not None and info.description is not None:
                        # Hackily reapplying the paraphraser when using the cache.
                        info.description = dataset.paraphraser.sample_paraphrase(
                            info.meta.path, info.description)
        # prepare attributes
        attributes = [info.to_condition_attributes() for info in infos]
        attributes = self.model.cfg_dropout(attributes)
        attributes = self.model.att_dropout(attributes)
        tokenized = self.model.condition_provider.tokenize(attributes)

        # Now we should be synchronization free.
        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        if audio_tokens is None:
            with torch.no_grad():
                audio_tokens, scale = self.compression_model.encode(audio)
                assert scale is None, "Scaled compression model not supported with LM."

        with self.autocast:
            condition_tensors = self.model.condition_provider(tokenized)

        # create a padding mask to hold valid vs invalid positions
        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
        # replace encodec tokens from padded audio with special_token_id
        if self.cfg.tokens.padding_with_special_token:
            audio_tokens = audio_tokens.clone()
            padding_mask = padding_mask.clone()
            token_sample_rate = self.compression_model.frame_rate
            B, K, T_s = audio_tokens.shape
            for i in range(B):
                n_samples = infos[i].n_frames
                audio_sample_rate = infos[i].sample_rate
                # take the last token generated from actual audio frames (non-padded audio)
                valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
                audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
                padding_mask[i, :, valid_tokens:] = 0

        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")

        if self._cached_batch_writer is not None and self.current_stage == 'train':
            assert self._cached_batch_loader is None
            assert audio_tokens is not None
            for info, one_audio_tokens in zip(infos, audio_tokens):
                assert isinstance(info, AudioInfo)
                if isinstance(info, MusicInfo):
                    assert not info.joint_embed, "joint_embed and cache not supported yet."
                    info.self_wav = None
                assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
                info.audio_tokens = one_audio_tokens.short().cpu()
            self._cached_batch_writer.save(infos)

        return condition_tensors, audio_tokens, padding_mask

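    # A worked example of the padding arithmetic above, assuming 32 kHz audio and
    # a compression model running at a 50 Hz frame rate (illustrative numbers):
    # a segment with n_frames = 48000 valid audio samples keeps
    #   valid_tokens = floor(48000 / 32000 * 50) = 75
    # token steps; positions [75:] are set to special_token_id and masked out.
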
    def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
        """Perform one training or valid step on a given batch."""
        check_synchronization_points = idx == 1 and self.device == 'cuda'

        condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points)

        self.deadlock_detect.update('tokens_and_conditions')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('warn')

        with self.autocast:
            model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors)  # type: ignore
            logits = model_output.logits
            mask = padding_mask & model_output.mask
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
            loss = ce
        self.deadlock_detect.update('loss')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('default')

        if self.is_training:
            metrics['lr'] = self.optimizer.param_groups[0]['lr']
            if self.scaler is not None:
                loss = self.scaler.scale(loss)
            self.deadlock_detect.update('scale')
            if self.cfg.fsdp.use:
                loss.backward()
                flashy.distrib.average_tensors(self.model.buffers())
            elif self.cfg.optim.eager_sync:
                with flashy.distrib.eager_sync_model(self.model):
                    loss.backward()
            else:
                # this should always be slower but can be useful
                # for weird use cases like multiple backwards.
                loss.backward()
                flashy.distrib.sync_model(self.model)
            self.deadlock_detect.update('backward')

            if self.scaler is not None:
                self.scaler.unscale_(self.optimizer)
            if self.cfg.optim.max_norm:
                if self.cfg.fsdp.use:
                    metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                else:
                    metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.optim.max_norm
                    )
            if self.scaler is None:
                self.optimizer.step()
            else:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.deadlock_detect.update('optim')
            if self.scaler is not None:
                scale = self.scaler.get_scale()
                metrics['grad_scale'] = scale
            if not loss.isfinite().all():
                raise RuntimeError("Model probably diverged.")

        metrics['ce'] = ce
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        return metrics

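    # The perplexity logged in run_step above is simply exp(cross entropy): e.g.
    # a cross entropy of 2.0 nats gives ppl = e**2 ≈ 7.39, roughly the
    # uncertainty of a uniform choice among 7-8 tokens.
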
    @torch.no_grad()
    def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
                          gen_duration: float, prompt_duration: tp.Optional[float] = None,
                          remove_prompt: bool = False,
                          **generation_params) -> dict:
        """Run generate step on a batch of optional audio tensor and corresponding attributes.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Batch of audio tensor
                and corresponding metadata. A prompt is extracted from the audio when doing continuation.
            gen_duration (float): Target audio duration for the generation.
            prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
            remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
            generation_params: Additional generation parameters.
        Returns:
            gen_outputs (dict): Generation outputs, consisting of audio and audio tokens from both
                the generation and the prompt, along with additional information.
        """
        bench_start = time.time()
        audio, meta = batch
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(meta)})"
        )
        # prepare attributes
        attributes = [x.to_condition_attributes() for x in meta]
        # TODO: Add dropout for chroma?

        # prepare audio prompt
        if prompt_duration is None:
            prompt_audio = None
        else:
            assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
            prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
            prompt_audio = audio[..., :prompt_audio_frames]

        # get audio tokens from compression model
        if prompt_audio is None or prompt_audio.nelement() == 0:
            num_samples = len(attributes)
            prompt_tokens = None
        else:
            num_samples = None
            prompt_audio = prompt_audio.to(self.device)
            prompt_tokens, scale = self.compression_model.encode(prompt_audio)
            assert scale is None, "Compression model in MusicGen should not require rescaling."

        # generate by sampling from the LM
        with self.autocast:
            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
            gen_tokens = self.model.generate(
                prompt_tokens, attributes, max_gen_len=total_gen_len,
                num_samples=num_samples, **self.generation_params)

        # generate audio from tokens
        assert gen_tokens.dim() == 3
        gen_audio = self.compression_model.decode(gen_tokens, None)

        bench_end = time.time()
        gen_outputs = {
            'rtf': (bench_end - bench_start) / gen_duration,
            'ref_audio': audio,
            'gen_audio': gen_audio,
            'gen_tokens': gen_tokens,
            'prompt_audio': prompt_audio,
            'prompt_tokens': prompt_tokens,
        }
        return gen_outputs

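    # A hedged usage sketch for run_generate_step above, assuming a solver set up
    # by the training pipeline and a 50 Hz token frame rate: asking for 10 s
    # continuations from 3 s prompts samples
    #   total_gen_len = ceil(10 * 50) = 500
    # token steps, the first ceil(3 * 50) = 150 of which come from the encoded prompt:
    #
    #   gen_outputs = solver.run_generate_step(batch, gen_duration=10.0, prompt_duration=3.0)
    #   gen_audio = gen_outputs['gen_audio']  # [B, C, ~10 s * sample_rate]
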
    def generate_audio(self) -> dict:
        """Audio generation stage."""
        generate_stage_name = f'{self.current_stage}'
        sample_manager = SampleManager(self.xp)
        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        dataset = get_dataset_from_loader(loader)
        dataset_duration = dataset.segment_duration
        assert dataset_duration is not None
        assert isinstance(dataset, AudioDataset)
        target_duration = self.cfg.generate.lm.gen_duration
        prompt_duration = self.cfg.generate.lm.prompt_duration
        if target_duration is None:
            target_duration = dataset_duration
        if prompt_duration is None:
            prompt_duration = dataset_duration / 4
        assert prompt_duration < dataset_duration, (
            f"Specified prompt duration ({prompt_duration}s) is longer"
            f" than reference audio duration ({dataset_duration}s)"
        )

        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
            hydrated_conditions = []
            for sample in [x.to_condition_attributes() for x in meta]:
                cond_dict = {}
                for cond_type in sample.__annotations__.keys():
                    for cond_key, cond_val in getattr(sample, cond_type).items():
                        if cond_key not in self.model.condition_provider.conditioners.keys():
                            continue
                        if is_jsonable(cond_val):
                            cond_dict[cond_key] = cond_val
                        elif isinstance(cond_val, WavCondition):
                            cond_dict[cond_key] = cond_val.path
                        elif isinstance(cond_val, JointEmbedCondition):
                            cond_dict[cond_key] = cond_val.text  # only support text at inference for now
                        else:
                            # if we reached this point, it is not clear how to log the condition
                            # so we just log the type.
                            cond_dict[cond_key] = str(type(cond_val))
                            continue
                hydrated_conditions.append(cond_dict)
            return hydrated_conditions

        metrics: dict = {}
        average = flashy.averager()
        for batch in lp:
            audio, meta = batch
            # metadata for sample manager
            hydrated_conditions = get_hydrated_conditions(meta)
            sample_generation_params = {
                **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
                **self.generation_params
            }
            if self.cfg.generate.lm.unprompted_samples:
                if self.cfg.generate.lm.gen_gt_samples:
                    # get the ground truth instead of generation
                    self.logger.warning(
                        "Using ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
                    gen_unprompted_audio = audio
                    rtf = 1.
                else:
                    gen_unprompted_outputs = self.run_generate_step(
                        batch, gen_duration=target_duration, prompt_duration=None,
                        **self.generation_params)
                    gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
                    rtf = gen_unprompted_outputs['rtf']
                sample_manager.add_samples(
                    gen_unprompted_audio, self.epoch, hydrated_conditions,
                    ground_truth_wavs=audio, generation_args=sample_generation_params)

            if self.cfg.generate.lm.prompted_samples:
                gen_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=prompt_duration,
                    **self.generation_params)
                gen_audio = gen_outputs['gen_audio'].cpu()
                prompt_audio = gen_outputs['prompt_audio'].cpu()
                sample_manager.add_samples(
                    gen_audio, self.epoch, hydrated_conditions,
                    prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                    generation_args=sample_generation_params)

            metrics['rtf'] = rtf
            metrics = average(metrics)

        flashy.distrib.barrier()
        return metrics

    def generate(self) -> dict:
        """Generate stage."""
        self.model.eval()
        with torch.no_grad():
            return self.generate_audio()

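    # The 'rtf' metric averaged above is the real-time factor: wall-clock
    # generation time divided by generated audio duration, so rtf < 1 means the
    # model generates faster than real time. A small worked example for the
    # cache sharding in run_epoch below, assuming cfg.cache.write_num_shards = 4
    # and cfg.cache.write_shard = 1: the epoch only runs when
    # (epoch - 1) % 4 == 1, i.e. on epochs 2, 6, 10, ...
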
    def run_epoch(self):
        if self.cfg.cache.write:
            if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard:
                return
        super().run_epoch()

    def train(self):
        """Train stage."""
        if self._cached_batch_writer is not None:
            self._cached_batch_writer.start_epoch(self.epoch)
        if self._cached_batch_loader is None:
            dataset = get_dataset_from_loader(self.dataloaders['train'])
            assert isinstance(dataset, AudioDataset)
            dataset.current_epoch = self.epoch
        else:
            self._cached_batch_loader.start_epoch(self.epoch)
        return super().train()

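    # Note on the `use_gt` flags handled below: when cfg.metrics.<name>.use_gt is
    # set, the metric is fed ground-truth audio passed through the compression
    # model (or the raw ground truth, for text consistency) instead of the LM
    # generations, which gives a reference score for the metric pipeline itself.
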
    def evaluate_audio_generation(self) -> dict:
        """Evaluate audio generation with off-the-shelf metrics."""
        evaluate_stage_name = f'{self.current_stage}_generation'
        # instantiate evaluation metrics; if at least one metric is defined, run the audio generation evaluation
        fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
        kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
        text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
        chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
        should_run_eval = False
        eval_chroma_wavs: tp.Optional[torch.Tensor] = None
        if self.cfg.evaluate.metrics.fad:
            fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.kld:
            kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.text_consistency:
            text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.chroma_cosine:
            chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
            # if we have predefined wavs for chroma we should purge them for computing the cosine metric
            has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
                self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
            if has_predefined_eval_chromas:
                warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
                                       "Resetting eval chromas to None for evaluation.")
                eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
                self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
            should_run_eval = True

        def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
            audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
            compressed_audio = self.compression_model.decode(audio_tokens, scale)
            return compressed_audio[..., :audio.shape[-1]]

        metrics: dict = {}
        if should_run_eval:
            loader = self.dataloaders['evaluate']
            updates = len(loader)
            lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
            average = flashy.averager()
            dataset = get_dataset_from_loader(loader)
            assert isinstance(dataset, AudioDataset)
            self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")

            for idx, batch in enumerate(lp):
                audio, meta = batch
                assert all([self.cfg.sample_rate == m.sample_rate for m in meta])

                target_duration = audio.shape[-1] / self.cfg.sample_rate
                if self.cfg.evaluate.fixed_generation_duration:
                    target_duration = self.cfg.evaluate.fixed_generation_duration

                gen_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration,
                    **self.generation_params
                )
                y_pred = gen_outputs['gen_audio'].detach()
                y_pred = y_pred[..., :audio.shape[-1]]

                normalize_kwargs = dict(self.cfg.generate.audio)
                normalize_kwargs.pop('format', None)
                y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
                y = audio.cpu()  # should already be on CPU but just in case
                sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
                sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
                audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]

                if fad is not None:
                    if self.cfg.metrics.fad.use_gt:
                        y_pred = get_compressed_audio(y).cpu()
                    fad.update(y_pred, y, sizes, sample_rates, audio_stems)
                if kldiv is not None:
                    if self.cfg.metrics.kld.use_gt:
                        y_pred = get_compressed_audio(y).cpu()
                    kldiv.update(y_pred, y, sizes, sample_rates)
                if text_consistency is not None:
                    texts = [m.description for m in meta]
                    if self.cfg.metrics.text_consistency.use_gt:
                        y_pred = y
                    text_consistency.update(y_pred, texts, sizes, sample_rates)
                if chroma_cosine is not None:
                    if self.cfg.metrics.chroma_cosine.use_gt:
                        y_pred = get_compressed_audio(y).cpu()
                    chroma_cosine.update(y_pred, y, sizes, sample_rates)
                    # restore chroma conditioner's eval chroma wavs
                    if eval_chroma_wavs is not None:
                        self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)

            flashy.distrib.barrier()
            if fad is not None:
                metrics['fad'] = fad.compute()
            if kldiv is not None:
                kld_metrics = kldiv.compute()
                metrics.update(kld_metrics)
            if text_consistency is not None:
                metrics['text_consistency'] = text_consistency.compute()
            if chroma_cosine is not None:
                metrics['chroma_cosine'] = chroma_cosine.compute()
            metrics = average(metrics)
            metrics = flashy.distrib.average_metrics(metrics, len(loader))

        return metrics

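    # The dict returned by evaluate() below merges the base LM metrics from
    # common_train_valid (e.g. ce/ppl) with the generation metrics computed
    # above (e.g. fad, kld, text_consistency); the exact keys depend on which
    # metrics the config enables (illustrative, not an exhaustive list).
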
    def evaluate(self) -> dict:
        """Evaluate stage."""
        self.model.eval()
        with torch.no_grad():
            metrics: dict = {}
            if self.cfg.evaluate.metrics.base:
                metrics.update(self.common_train_valid('evaluate'))
            gen_metrics = self.evaluate_audio_generation()
            return {**metrics, **gen_metrics}