Commit 70c420e
unpairedelectron07 committed
Parent(s): eb6b37d

Upload 3 files
Files changed:
- audiocraft/environment.py +176 -0
- audiocraft/py.typed +0 -0
- audiocraft/train.py +163 -0
audiocraft/environment.py
ADDED
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
"""

import logging
import os
from pathlib import Path
import re
import typing as tp

import omegaconf

from .utils.cluster import _guess_cluster_type


logger = logging.getLogger(__name__)


class AudioCraftEnvironment:
    """Environment configuration for teams and clusters.

    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across clusters.

    The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
    Use the following environment variables to specify the cluster, team or configuration:

        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
            cannot be inferred automatically.
        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
            If not set, configuration is read from config/teams.yaml.
        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
            Cluster configurations are shared across teams to match compute allocation,
            specify your cluster configuration in the configuration file under a key mapping
            your team name.
    """
    _instance = None
    DEFAULT_TEAM = "default"

    def __init__(self) -> None:
        """Loads configuration."""
        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
        cluster_type = _guess_cluster_type()
        cluster = os.getenv(
            "AUDIOCRAFT_CLUSTER", cluster_type.value
        )
        logger.info("Detecting cluster type %s", cluster_type)

        self.cluster: str = cluster

        config_path = os.getenv(
            "AUDIOCRAFT_CONFIG",
            Path(__file__)
            .parent.parent.joinpath("config/teams", self.team)
            .with_suffix(".yaml"),
        )
        self.config = omegaconf.OmegaConf.load(config_path)
        self._dataset_mappers = []
        cluster_config = self._get_cluster_config()
        if "dataset_mappers" in cluster_config:
            for pattern, repl in cluster_config["dataset_mappers"].items():
                regex = re.compile(pattern)
                self._dataset_mappers.append((regex, repl))

    def _get_cluster_config(self) -> omegaconf.DictConfig:
        assert isinstance(self.config, omegaconf.DictConfig)
        return self.config[self.cluster]

    @classmethod
    def instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """Clears the environment and forces a reload on next invocation."""
        cls._instance = None

    @classmethod
    def get_team(cls) -> str:
        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
        If not defined, defaults to "default".
        """
        return cls.instance().team

    @classmethod
    def get_cluster(cls) -> str:
        """Gets the detected cluster.
        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
        """
        return cls.instance().cluster

    @classmethod
    def get_dora_dir(cls) -> Path:
        """Gets the path to the dora directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
        logger.warning(f"Dora directory: {dora_dir}")
        return Path(dora_dir)

    @classmethod
    def get_reference_dir(cls) -> Path:
        """Gets the path to the reference directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))

    @classmethod
    def get_slurm_exclude(cls) -> tp.Optional[str]:
        """Get the list of nodes to exclude for that cluster."""
        cluster_config = cls.instance()._get_cluster_config()
        return cluster_config.get("slurm_exclude")

    @classmethod
    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
        """Gets the requested partitions for the current team and cluster as a comma-separated string.

        Args:
            partition_types (list[str], optional): partition types to retrieve. Values must be
                from ['global', 'team']. If not provided, the global partition is returned.
        """
        if not partition_types:
            partition_types = ["global"]

        cluster_config = cls.instance()._get_cluster_config()
        partitions = [
            cluster_config["partitions"][partition_type]
            for partition_type in partition_types
        ]
        return ",".join(partitions)

    @classmethod
    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
        """Converts reference placeholder in path with configured reference dir to resolve paths.

        Args:
            path (str or Path): Path to resolve.
        Returns:
            Path: Resolved path.
        """
        path = str(path)

        if path.startswith("//reference"):
            reference_dir = cls.get_reference_dir()
            logger.warn(f"Reference directory: {reference_dir}")
            assert (
                reference_dir.exists() and reference_dir.is_dir()
            ), f"Reference directory does not exist: {reference_dir}."
            path = re.sub("^//reference", str(reference_dir), path)

        return Path(path)

    @classmethod
    def apply_dataset_mappers(cls, path: str) -> str:
        """Applies dataset mapping regex rules as defined in the configuration.
        If no rules are defined, the path is returned as-is.
        """
        instance = cls.instance()

        for pattern, repl in instance._dataset_mappers:
            path = pattern.sub(repl, path)

        return path
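For orientation, everything in the file above is driven by environment variables plus the team YAML config. The short sketch below is only an illustration: it assumes audiocraft is importable, that config/teams/default.yaml exists, and that it declares an entry for the cluster key used here ("local" is an illustrative name, not a guaranteed value of the detected cluster type).

# Usage sketch for AudioCraftEnvironment (illustrative only).
# Assumes config/teams/default.yaml contains a "local" cluster entry with
# dora_dir and reference_dir keys; adjust the names to your own setup.
import os

os.environ["AUDIOCRAFT_TEAM"] = "default"    # selects config/teams/default.yaml
os.environ["AUDIOCRAFT_CLUSTER"] = "local"   # bypasses automatic cluster detection

from audiocraft.environment import AudioCraftEnvironment

print(AudioCraftEnvironment.get_team())      # -> "default"
print(AudioCraftEnvironment.get_cluster())   # -> "local"
print(AudioCraftEnvironment.get_dora_dir())  # dora_dir declared for that team/cluster

# Paths starting with //reference are rewritten against the configured
# reference_dir; the directory must exist or resolve_reference_path asserts.
# The path below is a placeholder, not a real manifest.
print(AudioCraftEnvironment.resolve_reference_path("//reference/manifests/train.jsonl"))

Note that the configuration is cached in a singleton, so AudioCraftEnvironment.reset() must be called to force a reload after changing these environment variables.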
audiocraft/py.typed
ADDED
Empty file (the py.typed marker contains no content)
audiocraft/train.py
ADDED
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Entry point for dora to launch solvers for running training loops.
See more info on how to use dora: https://github.com/facebookresearch/dora
"""

import logging
import multiprocessing
import os
from pathlib import Path
import sys
import typing as tp

from dora import git_save, hydra_main, XP
import flashy
import hydra
import omegaconf

from .environment import AudioCraftEnvironment
from .utils.cluster import get_slurm_parameters

logger = logging.getLogger(__name__)


def resolve_config_dset_paths(cfg):
    """Enable Dora to load manifest from git clone repository."""
    # manifest files for the different splits
    for key, value in cfg.datasource.items():
        if isinstance(value, str):
            cfg.datasource[key] = git_save.to_absolute_path(value)


def get_solver(cfg):
    from . import solvers
    # Convert batch size to batch size for each GPU
    assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0
    cfg.dataset.batch_size //= flashy.distrib.world_size()
    for split in ['train', 'valid', 'evaluate', 'generate']:
        if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'):
            assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0
            cfg.dataset[split].batch_size //= flashy.distrib.world_size()
    resolve_config_dset_paths(cfg)
    solver = solvers.get_solver(cfg)
    return solver


def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                       restore: bool = True, load_best: bool = True,
                       ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
    """Given a XP, return the Solver object.

    Args:
        xp (XP): Dora experiment for which to retrieve the solver.
        override_cfg (dict or None): If not None, should be a dict used to
            override some values in the config of `xp`. This will not impact
            the XP signature or folder. The format is different
            than the one used in Dora grids, nested keys should actually be nested dicts,
            not flattened, e.g. `{'optim': {'batch_size': 32}}`.
        restore (bool): If `True` (the default), restore state from the last checkpoint.
        load_best (bool): If `True` (the default), load the best state from the checkpoint.
        ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
        disable_fsdp (bool): if True, disables FSDP entirely. This will
            also automatically skip loading the EMA. For solver specific
            state sources, like the optimizer, you might want to
            use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
    """
    logger.info(f"Loading solver from XP {xp.sig}. "
                f"Overrides used: {xp.argv}")
    cfg = xp.cfg
    if override_cfg is not None:
        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
    if disable_fsdp and cfg.fsdp.use:
        cfg.fsdp.use = False
        assert load_best is True
        # ignoring some keys that were FSDP sharded like model, ema, and best_state.
        # fsdp_best_state will be used in that case. When using a specific solver,
        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
        # We could make something to automatically register those inside the solver, but that
        # seem overkill at this point.
        ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']

    try:
        with xp.enter():
            solver = get_solver(cfg)
            if restore:
                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
        return solver
    finally:
        hydra.core.global_hydra.GlobalHydra.instance().clear()


def get_solver_from_sig(sig: str, *args, **kwargs):
    """Return Solver object from Dora signature, i.e. to play with it from a notebook.
    See `get_solver_from_xp` for more information.
    """
    xp = main.get_xp_from_sig(sig)
    return get_solver_from_xp(xp, *args, **kwargs)


def init_seed_and_system(cfg):
    import numpy as np
    import torch
    import random
    from audiocraft.modules.transformer import set_efficient_attention_backend

    multiprocessing.set_start_method(cfg.mp_start_method)
    logger.debug('Setting mp start method to %s', cfg.mp_start_method)
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    # torch also initialize cuda seed if available
    torch.manual_seed(cfg.seed)
    torch.set_num_threads(cfg.num_threads)
    os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads)
    os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads)
    logger.debug('Setting num threads to %d', cfg.num_threads)
    set_efficient_attention_backend(cfg.efficient_attention_backend)
    logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
    if 'SLURM_JOB_ID' in os.environ:
        tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
        if tmpdir.exists():
            logger.info("Changing tmpdir to %s", tmpdir)
            os.environ['TMPDIR'] = str(tmpdir)


@hydra_main(config_path='../config', config_name='config', version_base='1.1')
def main(cfg):
    init_seed_and_system(cfg)

    # Setup logging both to XP specific folder, and to stderr.
    log_name = '%s.log.{rank}' % cfg.execute_only if cfg.execute_only else 'solver.log.{rank}'
    flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name)
    # Initialize distributed training, no need to specify anything when using Dora.
    flashy.distrib.init()
    solver = get_solver(cfg)
    if cfg.show:
        solver.show()
        return

    if cfg.execute_only:
        assert cfg.execute_inplace or cfg.continue_from is not None, \
            "Please explicitly specify the checkpoint to continue from with continue_from=<sig_or_path> " + \
            "when running with execute_only or set execute_inplace to True."
        solver.restore(replay_metrics=False)  # load checkpoint
        solver.run_one_stage(cfg.execute_only)
        return

    return solver.run()


main.dora.dir = AudioCraftEnvironment.get_dora_dir()
main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm)

if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK):
    print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr)
    main.dora.shared = None

if __name__ == '__main__':
    main()
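As the module docstring notes, training itself is launched through Dora (main is the hydra_main entry point wired to the team's dora directory), while get_solver_from_sig is the hook for pulling a finished run back into a notebook. A minimal sketch follows; the signature string is a placeholder rather than a real experiment:

# Sketch only: recovering the solver of a finished XP from a notebook.
# "abc123de" is a hypothetical Dora signature, not a real run.
from audiocraft.train import get_solver_from_sig

solver = get_solver_from_sig(
    "abc123de",                        # hypothetical XP signature under the dora dir
    restore=True,                      # reload checkpoint state
    load_best=True,                    # required when disable_fsdp=True (the default)
    ignore_state_keys=["optimizer"],   # skip solver-specific state, as the docstring advises
)
# solver now holds the restored config and weights for interactive inspection.

From the command line, this module is normally driven through the Dora CLI (dora run with the desired solver overrides) rather than invoked directly; see the Dora link in the module docstring.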