# NOTE(review): page-banner residue captured with this file, not part of the module: "Spaces: Runtime error"
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
""" | |
Utility functions for SLURM configuration and cluster settings. | |
""" | |
from enum import Enum | |
import os | |
import socket | |
import typing as tp | |
import omegaconf | |
class ClusterType(Enum):
    """Cluster environments that get distinct SLURM settings in this module."""

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
    LOCAL_DARWIN = "darwin"
    # Fallback used for any other cluster.
    DEFAULT = "default"
def _guess_cluster_type() -> ClusterType:
    """Infer the cluster type from the host OS identity and its FQDN.

    Checks run in order and the first match wins:

    1. Linux with a kernel release ending in ``-aws`` or an FQDN containing
       ``.ec2`` -> ``ClusterType.AWS``
    2. FQDN ending in ``.fair`` -> ``ClusterType.FAIR``
    3. FQDN ending in ``.facebook.com`` -> ``ClusterType.RSC``
    4. Darwin system -> ``ClusterType.LOCAL_DARWIN``
    5. anything else -> ``ClusterType.DEFAULT``

    Returns:
        The first matching ``ClusterType``, or ``ClusterType.DEFAULT``.
    """
    # `os.uname` is only available on Unix; calling it on Windows raises
    # AttributeError before the DEFAULT fallback can be reached. `platform`
    # reports the same system/release strings on POSIX hosts and is
    # available on every platform.
    import platform

    system = platform.system()
    release = platform.release()
    fqdn = socket.getfqdn()

    if system == "Linux" and (release.endswith("-aws") or ".ec2" in fqdn):
        return ClusterType.AWS

    if fqdn.endswith(".fair"):
        return ClusterType.FAIR

    if fqdn.endswith(".facebook.com"):
        return ClusterType.RSC

    if system == "Darwin":
        return ClusterType.LOCAL_DARWIN

    return ClusterType.DEFAULT
def get_cluster_type(
    cluster_type: tp.Optional[ClusterType] = None,
) -> tp.Optional[ClusterType]:
    """Return ``cluster_type`` unchanged, or a guessed one when it is ``None``.

    Args:
        cluster_type: Explicit cluster type; ``None`` triggers detection
            through ``_guess_cluster_type()``.

    Returns:
        The provided cluster type, or the detected one.
    """
    return _guess_cluster_type() if cluster_type is None else cluster_type
def get_slurm_parameters(
    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
) -> omegaconf.DictConfig:
    """Adjust the SLURM parameters in ``cfg`` for the given cluster type.

    If the cluster type is not specified, it is inferred automatically.
    ``cfg`` is modified in place and the same object is returned.

    Args:
        cfg: SLURM configuration to update.
        cluster_type: Cluster to tune for; ``None`` means auto-detect.

    Returns:
        The updated ``cfg``.
    """
    from ..environment import AudioCraftEnvironment

    resolved_type = get_cluster_type(cluster_type)

    # Cluster-specific overrides, listed in the order they are written to cfg.
    if resolved_type == ClusterType.AWS:
        overrides = {
            "mem_per_gpu": None,
            "constraint": None,
            "setup": [],
        }
    elif resolved_type == ClusterType.RSC:
        overrides = {
            "mem_per_gpu": None,
            "setup": [],
            "constraint": None,
            "partition": "learn",
        }
    else:
        overrides = {}

    for key, value in overrides.items():
        cfg[key] = value

    # Only set "exclude" when the environment provides a value for it.
    excluded = AudioCraftEnvironment.get_slurm_exclude()
    if excluded is not None:
        cfg["exclude"] = excluded

    return cfg