# Aggregates names from this package's sibling modules into a single namespace:
# every statement below is a relative import, so each listed name becomes
# importable from this module directly. No other logic lives here.
# NOTE(review): presumably this is the package's `__init__` — confirm from the
# file path; the order of the statements is kept exactly as found.
from .constants import (
    MODEL_NAME,
    OPTIMIZER_NAME,
    RNG_STATE_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    SCALER_NAME,
    SCHEDULER_NAME,
    TORCH_LAUNCH_PARAMS,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from .dataclasses import (
    BnbQuantizationConfig,
    ComputeEnvironment,
    CustomDtype,
    DeepSpeedPlugin,
    DistributedDataParallelKwargs,
    DistributedType,
    DynamoBackend,
    FP8RecipeKwargs,
    FullyShardedDataParallelPlugin,
    GradientAccumulationPlugin,
    GradScalerKwargs,
    InitProcessGroupKwargs,
    KwargsHandler,
    LoggerType,
    MegatronLMPlugin,
    PrecisionType,
    ProjectConfiguration,
    RNGType,
    SageMakerDistributedType,
    TensorInformation,
    TorchDynamoPlugin,
)
from .environment import get_int_from_env, parse_choice_from_env, parse_flag_from_env

# `is_deepspeed_available` imported here is called further down to guard the
# `.deepspeed` import, so this statement must stay above that guard.
from .imports import (
    get_ccl_version,
    is_4bit_bnb_available,
    is_8bit_bnb_available,
    is_aim_available,
    is_bf16_available,
    is_bnb_available,
    is_boto3_available,
    is_ccl_available,
    is_comet_ml_available,
    is_datasets_available,
    is_deepspeed_available,
    is_fp8_available,
    is_ipex_available,
    is_megatron_lm_available,
    is_mlflow_available,
    is_mps_available,
    is_npu_available,
    is_rich_available,
    is_safetensors_available,
    is_sagemaker_available,
    is_tensorboard_available,
    is_tpu_available,
    is_transformers_available,
    is_wandb_available,
    is_xpu_available,
)
from .modeling import (
    check_device_map,
    check_tied_parameters_in_config,
    check_tied_parameters_on_same_device,
    compute_module_sizes,
    convert_file_size_to_int,
    dtype_byte_size,
    find_tied_parameters,
    get_balanced_memory,
    get_max_layer_size,
    get_max_memory,
    get_mixed_precision_context_manager,
    id_tensor_storage,
    infer_auto_device_map,
    load_checkpoint_in_model,
    load_offloaded_weights,
    load_state_dict,
    named_module_tensors,
    retie_parameters,
    set_module_tensor_to_device,
    shard_checkpoint,
)
from .offload import (
    OffloadedWeightsLoader,
    PrefixedDataset,
    extract_submodules_state_dict,
    load_offloaded_weight,
    offload_state_dict,
    offload_weight,
    save_offload_index,
)
from .operations import (
    broadcast,
    broadcast_object_list,
    concatenate,
    convert_outputs_to_fp32,
    convert_to_fp32,
    find_batch_size,
    find_device,
    gather,
    gather_object,
    get_data_structure,
    honor_type,
    initialize_tensors,
    is_namedtuple,
    is_tensor_information,
    is_torch_tensor,
    listify,
    pad_across_processes,
    recursively_apply,
    reduce,
    send_to_device,
    slice_tensors,
)
from .versions import compare_versions, is_torch_version


# The DeepSpeed wrapper names are bound only when `is_deepspeed_available()`
# returns a truthy value; otherwise they are simply absent from this namespace.
# NOTE(review): the original indentation was lost in this copy of the file;
# only the `.deepspeed` import is assumed to sit under this guard — confirm
# against the upstream layout before relying on it.
if is_deepspeed_available():
    from .deepspeed import (
        DeepSpeedEngineWrapper,
        DeepSpeedOptimizerWrapper,
        DeepSpeedSchedulerWrapper,
        DummyOptim,
        DummyScheduler,
        HfDeepSpeedConfig,
    )

from .bnb import has_4bit_bnb_layers, load_and_quantize_model
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer

# NOTE(review): `prepare_sagemager_args_inputs` looks like a misspelling of
# "sagemaker", but the name has to match whatever `.launch` defines; do not
# correct it here without renaming it at its definition and call sites too.
from .launch import (
    PrepareForLaunch,
    _filter_args,
    prepare_deepspeed_cmd_env,
    prepare_multi_gpu_env,
    prepare_sagemager_args_inputs,
    prepare_simple_launcher_cmd_env,
    prepare_tpu,
)
from .megatron_lm import (
    AbstractTrainStep,
    BertTrainStep,
    GPTTrainStep,
    MegatronEngine,
    MegatronLMDummyDataLoader,
    MegatronLMDummyScheduler,
    MegatronLMOptimizerWrapper,
    MegatronLMSchedulerWrapper,
    T5TrainStep,
    avg_losses_across_data_parallel_group,
    gather_across_data_parallel_groups,
)

# These five names are bound under a `megatron_lm_` prefix rather than their
# own names — presumably so that generic names such as `initialize` or
# `prepare_model` are not exposed unqualified from this namespace.
from .megatron_lm import initialize as megatron_lm_initialize
from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
from .megatron_lm import prepare_model as megatron_lm_prepare_model
from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
from .memory import find_executable_batch_size, release_memory
from .other import (
    extract_model_from_parallel,
    get_pretty_name,
    is_port_in_use,
    merge_dicts,
    patch_environment,
    save,
    wait_for_everyone,
    write_basic_config,
)
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
from .torch_xla import install_xla

# After this statement the name `tqdm` in this namespace refers to the object
# imported from the sibling `.tqdm` module.
from .tqdm import tqdm
from .transformer_engine import convert_model, has_transformer_engine_layers