"""Resampling script.
|
|
"""
import argparse
from pathlib import Path
import shutil
import typing as tp

import submitit
import tqdm

from audiocraft.data.audio import audio_read, audio_write
from audiocraft.data.audio_dataset import load_audio_meta, find_audio_files
from audiocraft.data.audio_utils import convert_audio
from audiocraft.environment import AudioCraftEnvironment


def read_txt_files(path: tp.Union[str, Path]):
    with open(path) as f:
        lines = [line.rstrip() for line in f]
    print(f"Read {len(lines)} lines from .txt")
    # Drop metadata/manifest entries, keep only audio file paths.
    lines = [line for line in lines if Path(line).suffix not in ['.json', '.txt', '.csv']]
    print(f"Filtered down to {len(lines)} files from .txt")
    return lines


def read_egs_files(path: tp.Union[str, Path]):
    path = Path(path)
    if path.is_dir():
        if (path / 'data.jsonl').exists():
            path = path / 'data.jsonl'
        elif (path / 'data.jsonl.gz').exists():
            path = path / 'data.jsonl.gz'
        else:
            raise ValueError("Don't know where to read metadata from in the dir. "
                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
    meta = load_audio_meta(path)
    return [m.path for m in meta]
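# Note on the egs metadata layout (an assumption based on how load_audio_meta is
# used here, not something this script spells out): each line of data.jsonl /
# data.jsonl.gz is expected to be a JSON object describing one audio file with at
# least a "path" field, e.g.
#   {"path": "/data/raw/song.wav", "duration": 12.3, "sample_rate": 44100}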


def process_dataset(args, n_shards: int, node_index: int, task_index: tp.Optional[int] = None):
    if task_index is None:
        # When launched through submitit, derive the task index from the job environment.
        env = submitit.JobEnvironment()
        task_index = env.global_rank
    # Each node runs `tasks_per_node` tasks; together they form `n_shards` shards.
    shard_index = node_index * args.tasks_per_node + task_index

    if args.files_path is None:
        lines = [m.path for m in find_audio_files(args.root_path, resolve=False, progress=True, workers=8)]
    else:
        files_path = Path(args.files_path)
        if files_path.suffix == '.txt':
            print(f"Reading file list from .txt file: {args.files_path}")
            lines = read_txt_files(args.files_path)
        else:
            print(f"Reading file list from egs: {args.files_path}")
            lines = read_egs_files(args.files_path)

    total_files = len(lines)
    print(
        f"Total of {total_files} files with {n_shards} shards. "
        f"Current shard index = {shard_index} -> {total_files // n_shards} files to process"
    )
    for idx, line in tqdm.tqdm(enumerate(lines)):
        # Only handle files assigned to this shard.
        if idx % n_shards != shard_index:
            continue

        path = str(AudioCraftEnvironment.apply_dataset_mappers(line))
        root_path = str(args.root_path)
        if not root_path.endswith('/'):
            root_path += '/'
        assert path.startswith(root_path), \
            f"Mismatch between path and provided root: {path} VS {root_path}"

        try:
            metadata_path = Path(path).with_suffix('.json')
            out_path = args.out_path / path[len(root_path):]
            out_metadata_path = out_path.with_suffix('.json')
            out_done_token = out_path.with_suffix('.done')

            # Skip files already processed in a previous run.
            if out_done_token.exists():
                continue

            print(idx, out_path, path)
            mix, sr = audio_read(path)
            mix_channels = args.channels if args.channels is not None and args.channels > 0 else mix.size(0)

            out_channels = mix_channels
            if out_channels > 2:
                print(f"Mix has more than two channels: {out_channels}, enforcing 2 channels")
                out_channels = 2
            out_sr = args.sample_rate if args.sample_rate is not None else sr
            out_wav = convert_audio(mix, sr, out_sr, out_channels)
            audio_write(out_path.with_suffix(''), out_wav, sample_rate=out_sr,
                        format=args.format, normalize=False, strategy='clip')
            if metadata_path.exists():
                shutil.copy(metadata_path, out_metadata_path)
            else:
                print(f"No metadata found at {metadata_path}")
            out_done_token.touch()
        except Exception as e:
            print(f"Error processing file line: {line}, {e}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Resample dataset with SLURM.")
    parser.add_argument(
        "--log_root",
        type=Path,
        default=Path.home() / 'tmp' / 'resample_logs',
        help="Root folder for submitit logs.",
    )
    parser.add_argument(
        "--files_path",
        type=Path,
        help="List of files to process, either .txt (one file per line) or a jsonl[.gz].",
    )
    parser.add_argument(
        "--root_path",
        type=Path,
        required=True,
        help="When rewriting paths, this will be the prefix to remove.",
    )
    parser.add_argument(
        "--out_path",
        type=Path,
        required=True,
        help="When rewriting paths, `root_path` will be replaced by this.",
    )
    parser.add_argument("--xp_name", type=str, default="shutterstock",
                        help="Job name used for SLURM and the log folder.")
    parser.add_argument(
        "--nodes",
        type=int,
        default=4,
        help="Number of single-node SLURM jobs to submit.",
    )
    parser.add_argument(
        "--tasks_per_node",
        type=int,
        default=20,
        help="Number of tasks (shards) per node.",
    )
    parser.add_argument(
        "--cpus_per_task",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--memory_gb",
        type=int,
        help="Memory in GB.",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="wav",
        help="Output audio format, passed to audio_write.",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=32000,
        help="Target sample rate.",
    )
    parser.add_argument(
        "--channels",
        type=int,
        help="Target number of channels (defaults to the input's, capped at 2).",
    )
    parser.add_argument(
        "--partition",
        default='learnfair',
        help="SLURM partition.",
    )
    parser.add_argument("--qos", help="SLURM QOS.")
    parser.add_argument("--account", help="SLURM account.")
    parser.add_argument("--timeout", type=int, default=4320, help="Job timeout in minutes.")
    parser.add_argument('--debug', action='store_true', help='debug mode (local run)')

    args = parser.parse_args()
    n_shards = args.tasks_per_node * args.nodes
    if args.files_path is None:
        print("Warning: --files_path not provided, scanning --root_path directly; "
              "this is not recommended when processing more than 10k files.")
    if args.debug:
        print("Debugging mode")
        process_dataset(args, n_shards=n_shards, node_index=0, task_index=0)
    else:
        log_folder = Path(args.log_root) / args.xp_name / '%j'
        print(f"Logging to: {log_folder}")
        log_folder.parent.mkdir(parents=True, exist_ok=True)
        executor = submitit.AutoExecutor(folder=str(log_folder))
        if args.qos:
            executor.update_parameters(slurm_partition=args.partition, slurm_qos=args.qos,
                                       slurm_account=args.account)
        else:
            executor.update_parameters(slurm_partition=args.partition)
        executor.update_parameters(
            slurm_job_name=args.xp_name, timeout_min=args.timeout,
            cpus_per_task=args.cpus_per_task, tasks_per_node=args.tasks_per_node, nodes=1)
        if args.memory_gb:
            executor.update_parameters(mem=f'{args.memory_gb}GB')
        jobs = []
        # Submit one single-node job per "node"; each job runs `tasks_per_node` tasks,
        # and each task processes its own shard of the file list.
        with executor.batch():
            for node_index in range(args.nodes):
                job = executor.submit(process_dataset, args, n_shards=n_shards, node_index=node_index)
                jobs.append(job)
        for job in jobs:
            print(f"Waiting on job {job.job_id}")
            job.results()