#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Run training from Slurm on all visible GPUs. Start only
one task per node as this script will spawn one child for each GPU.
This will not schedule a job but instead should be launched from srun/sbatch.
"""
import os
import subprocess as sp
import sys
import time

import torch as th

from demucs.utils import free_port


def main():
    args = sys.argv[1:]
    gpus = th.cuda.device_count()
    n_nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
    node_id = int(os.environ['SLURM_NODEID'])
    job_id = int(os.environ['SLURM_JOBID'])

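    # One child per local GPU on this node; assuming every node exposes the
    # same number of GPUs, the child for local GPU g on node n gets global
    # rank n * gpus + g.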
    rank_offset = gpus * node_id
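    # `scontrol show hostnames` expands the compact Slurm node list; the
    # first node in it serves as the rendezvous master for all workers.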
    hostnames = sp.run(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']],
                       capture_output=True,
                       check=True).stdout
    master_addr = hostnames.split()[0].decode('utf-8')

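    # Every node must pick the same rendezvous port without communicating,
    # so multi-node jobs derive it deterministically from the job id. A
    # single-node job can simply grab any free local port.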
    if n_nodes == 1:
        port = free_port()
    else:
        port = 20_000 + (job_id % 40_000)
    args += ["--world_size", str(n_nodes * gpus), "--master", f"{master_addr}:{port}"]
    tasks = []

    print("About to go live", master_addr, node_id, n_nodes, file=sys.stderr)
    sys.stderr.flush()

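    # Spawn one training process per visible GPU, silencing stdin and
    # stdout of every child but the first.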
    for gpu in range(gpus):
        kwargs = {}
        if gpu > 0:
            kwargs['stdin'] = sp.DEVNULL
            kwargs['stdout'] = sp.DEVNULL
            # We keep stderr to see tracebacks from children.
        proc = sp.Popen(["python3", "-m", "demucs"] + args +
                        ["--rank", str(rank_offset + gpu)], **kwargs)
        # Stash the global rank on the Popen object so the status messages
        # below can report which worker exited.
        proc.rank = rank_offset + gpu
        tasks.append(proc)

    failed = False
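    # Poll the children until they all exit; a single failure aborts the
    # whole job across all nodes.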
    try:
        while tasks:
            # Iterate over a copy: finished tasks are removed from `tasks`
            # inside the loop body.
            for task in list(tasks):
                try:
                    exitcode = task.wait(0.1)
                except sp.TimeoutExpired:
                    continue
                else:
                    tasks.remove(task)
                    if exitcode:
                        print(f"Task {task.rank} died with exit code "
                              f"{exitcode}",
                              file=sys.stderr)
                        failed = True
                    else:
                        print(f"Task {task.rank} exited successfully")
            if failed:
                break
            time.sleep(1)
    except KeyboardInterrupt:
        for task in tasks:
            task.terminate()
        raise
    if failed:
        for task in tasks:
            task.terminate()

        sp.run(["scancel", str(job_id)], check=True)
        sys.exit(1)


if __name__ == "__main__":
    main()