Spaces:
Running
Running
#!/usr/bin/env python | |
# Copyright 2021 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import json | |
import os | |
from dataclasses import dataclass | |
from enum import Enum | |
from typing import List, Optional, Union | |
import yaml | |
from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType | |
from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION | |
# Root of the Hugging Face cache: `HF_HOME` when set, otherwise
# `<XDG_CACHE_HOME or ~/.cache>/huggingface`; a leading `~` is expanded.
hf_cache_home = os.environ.get("HF_HOME")
if hf_cache_home is None:
    hf_cache_home = os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface")
hf_cache_home = os.path.expanduser(hf_cache_home)

# Everything accelerate stores lives in its own sub-directory of that cache.
cache_dir = os.path.join(hf_cache_home, "accelerate")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")
# NOTE(review): this resolves to the same `.yaml` path as `default_yaml_config_file`,
# so the JSON fallback below can never be selected. Confirm whether
# "default_config.json" was intended before changing it; callers may rely on this value.
default_json_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
default_config_file = (
    default_json_config_file
    if os.path.isfile(default_json_config_file) and not os.path.isfile(default_yaml_config_file)
    else default_yaml_config_file
)
def load_config_from_file(config_file):
    """Load a launch configuration from a JSON or YAML file.

    Args:
        config_file: Path to the configuration file, or `None` to use
            `default_config_file`. A path ending in ".json" is parsed as JSON;
            any other path is parsed as YAML.

    Returns:
        A `ClusterConfig` when the file's `compute_environment` entry equals
        `ComputeEnvironment.LOCAL_MACHINE` (also the fallback when the entry is
        missing), otherwise a `SageMakerConfig`. The instance is built by the
        chosen class's `from_json_file` / `from_yaml_file`.

    Raises:
        FileNotFoundError: If an explicit `config_file` is given but does not exist.
    """
    if config_file is None:
        config_file = default_config_file
    elif not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"The passed configuration file `{config_file}` does not exist. "
            "Please pass an existing file to `accelerate launch`, or use the default one "
            "created through `accelerate config` and run `accelerate launch` "
            "without the `--config_file` argument."
        )

    is_json = config_file.endswith(".json")
    with open(config_file, encoding="utf-8") as handle:
        # First pass only peeks at the compute environment to pick the config class;
        # the class-level loader below re-reads the file itself.
        raw_config = json.load(handle) if is_json else yaml.safe_load(handle)
        environment = raw_config.get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        if environment == ComputeEnvironment.LOCAL_MACHINE:
            config_class = ClusterConfig
        else:
            config_class = SageMakerConfig

        if is_json:
            return config_class.from_json_file(json_file=config_file)
        return config_class.from_yaml_file(yaml_file=config_file)
@dataclass
class BaseConfig:
    """Fields and (de)serialization helpers shared by every launch configuration.

    Subclasses are expected to be dataclasses as well: the loaders validate the
    keys found in a file against `cls.__dataclass_fields__`.
    """

    compute_environment: ComputeEnvironment
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool
    debug: bool

    def to_dict(self):
        """Return a serializable copy of the instance attributes.

        Enum members are replaced by their `.value`; attributes that are `None`
        or an empty dict are left out. The instance itself is not modified.
        """
        result = {}
        # Build a fresh dict: writing into `self.__dict__` would overwrite the
        # Enum / empty-dict attributes on the live object.
        for key, value in self.__dict__.items():
            # For serialization, it's best to convert Enums to strings (or their underlying value type).
            if isinstance(value, Enum):
                value = value.value
            elif isinstance(value, dict) and not value:
                continue
            if value is not None:
                result[key] = value
        return result

    @staticmethod
    def _apply_legacy_defaults(config_dict):
        """Upgrade old-format keys and fill in missing defaults, in place.

        Returns the same dict for convenience.
        """
        config_dict.setdefault("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if config_dict.get("fp16") else None
        # Convert the config to the new format: `fp16` was replaced by `mixed_precision`.
        config_dict.pop("fp16", None)
        if "dynamo_backend" in config_dict:  # Convert the config to the new format.
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        for flag in ("use_cpu", "debug", "enable_cpu_affinity"):
            config_dict.setdefault(flag, False)
        return config_dict

    @classmethod
    def _from_config_dict(cls, config_dict, config_file):
        """Instantiate `cls` from `config_dict`, rejecting keys that are not dataclass fields.

        Raises:
            ValueError: If `config_dict` contains keys unknown to `cls`.
        """
        extra_keys = sorted(set(config_dict) - set(cls.__dataclass_fields__))
        if extra_keys:
            raise ValueError(
                f"The config file at {config_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
                " version or fix (and potentially remove) these keys from your config file."
            )
        return cls(**config_dict)

    @classmethod
    def from_json_file(cls, json_file=None):
        """Build a config from a JSON file (`default_json_config_file` when `json_file` is None).

        Raises:
            ValueError: If the file contains keys that are not fields of `cls`.
        """
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        config_dict = cls._apply_legacy_defaults(config_dict)
        return cls._from_config_dict(config_dict, json_file)

    def to_json_file(self, json_file):
        """Write `to_dict()` to `json_file` as indented, key-sorted JSON with a trailing newline."""
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        """Build a config from a YAML file (`default_yaml_config_file` when `yaml_file` is None).

        Raises:
            ValueError: If the file contains keys that are not fields of `cls`.
        """
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        config_dict = cls._apply_legacy_defaults(config_dict)
        # PyYAML reads an unquoted `no` as the boolean False; map it back to the string.
        if config_dict["mixed_precision"] is False:
            config_dict["mixed_precision"] = "no"
        return cls._from_config_dict(config_dict, yaml_file)

    def to_yaml_file(self, yaml_file):
        """Write `to_dict()` to `yaml_file` using `yaml.safe_dump`."""
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        """Coerce string fields loaded from a file into their Enum types."""
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            # The distributed-type Enum depends on where the job runs.
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        # `dynamo_config` is declared by subclasses; make sure it is always a dict.
        if getattr(self, "dynamo_config", None) is None:
            self.dynamo_config = {}
@dataclass
class ClusterConfig(BaseConfig):
    """Launch configuration for the non-SageMaker case.

    Declared as a dataclass so that the annotated fields become constructor
    parameters, appear in `__dataclass_fields__` (used by the file loaders to
    validate keys) and so that `__post_init__` runs after construction.
    """

    num_processes: int
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # args for deepspeed_plugin
    deepspeed_config: Optional[dict] = None
    # args for fsdp
    fsdp_config: Optional[dict] = None
    # args for megatron_lm
    megatron_lm_config: Optional[dict] = None
    # args for ipex
    ipex_config: Optional[dict] = None
    # args for mpirun
    mpirun_config: Optional[dict] = None
    # args for TPU
    downcast_bf16: bool = False

    # args for TPU pods
    tpu_name: Optional[str] = None
    tpu_zone: Optional[str] = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: Optional[str] = None
    commands: Optional[List[str]] = None
    tpu_vm: Optional[List[str]] = None
    tpu_env: Optional[List[str]] = None

    # args for dynamo
    dynamo_config: Optional[dict] = None

    def __post_init__(self):
        """Replace unset plugin configs with empty dicts, then run the base-class coercions."""
        # `None` is used as the field default instead of a shared mutable `{}`;
        # each instance gets its own fresh dict here.
        for plugin_field in (
            "deepspeed_config",
            "fsdp_config",
            "megatron_lm_config",
            "ipex_config",
            "mpirun_config",
        ):
            if getattr(self, plugin_field) is None:
                setattr(self, plugin_field, {})
        return super().__post_init__()
@dataclass
class SageMakerConfig(BaseConfig):
    """Launch configuration used when the compute environment is not the local machine.

    Declared as a dataclass so that the annotated fields become constructor
    parameters and appear in `__dataclass_fields__`, which the file loaders use
    to validate keys.
    """

    ec2_instance_type: str
    iam_role_name: str
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    # Evaluated once while the class body runs, using the default `num_machines`
    # above; it does not follow a per-instance `num_machines`.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION
    sagemaker_inputs_file: Optional[str] = None
    sagemaker_metrics_file: Optional[str] = None
    additional_args: Optional[dict] = None
    dynamo_config: Optional[dict] = None
    # `BaseConfig.from_json_file` / `from_yaml_file` always add this key before
    # checking for unknown keys, so it must be a declared field or every load fails.
    enable_cpu_affinity: bool = False