# mini-omni-s2s/slam_llm/utils/config_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import inspect
# from dataclasses import asdict
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from peft import (
    LoraConfig,
    AdaptionPromptConfig,
    PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq
# from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
from omegaconf import OmegaConf
import logging
logger = logging.getLogger(__name__)

# def update_config(config, **kwargs):
#     if isinstance(config, (tuple, list)):
#         for c in config:
#             update_config(c, **kwargs)
#     else:
#         for k, v in kwargs.items():
#             if hasattr(config, k):
#                 setattr(config, k, v)
#             elif "." in k:
#                 # allow --some_config.some_param=True
#                 config_name, param_name = k.split(".")
#                 if type(config).__name__ == config_name:
#                     if hasattr(config, param_name):
#                         setattr(config, param_name, v)
#                     else:
#                         # in case of a specialized config we can warn the user
#                         logger.warning(f"Warning: {config_name} does not accept parameter: {k}")
#             elif isinstance(config, train_config):
#                 logger.warning(f"Warning: unknown parameter {k}")


def generate_peft_config(train_config):
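    """Build a PEFT config object from ``train_config.peft_config``.

    The ``peft_method`` field selects the config class ("lora",
    "llama_adapter", or "prefix"); the remaining fields are passed through
    as keyword arguments to that class's constructor.

    Example (a minimal sketch; the field values are illustrative
    ``LoraConfig`` parameters, not defaults taken from this repo):

        cfg = OmegaConf.create({
            "peft_config": {
                "peft_method": "lora",
                "r": 8,
                "lora_alpha": 32,
                "target_modules": ["q_proj", "v_proj"],
                "lora_dropout": 0.05,
            }
        })
        peft_config = generate_peft_config(cfg)  # -> peft.LoraConfig
    """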
    # configs = (lora_config, llama_adapter_config, prefix_config)
    # peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
    peft_configs = {
        "lora": LoraConfig,
        "llama_adapter": AdaptionPromptConfig,
        "prefix": PrefixTuningConfig,
    }
    # names = tuple(c.__name__.rstrip("_config") for c in configs)
    # assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
    # config = configs[names.index(train_config.peft_method)]()
    config = train_config.peft_config
    params = OmegaConf.to_container(config, resolve=True)
    # peft_config = peft_configs[names.index(train_config.peft_method)](**params)
    # (FIX:MZY): drop peft_method from params so it is not passed to the config constructor.
    params.pop("peft_method", None)
    peft_config = peft_configs[config.get("peft_method", "lora")](**params)
    return peft_config


def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
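    """Assemble torch DataLoader kwargs (sampler or batch_sampler, batch size,
    and collate_fn) for ``dataset`` according to
    ``train_config.batching_strategy`` and the distributed setup; ``mode`` is
    "train" for training, anything else for evaluation.
    """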
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.batching_strategy == "padding":
        # Length-based batching groups similar-length samples to minimize padding.
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        else:
            kwargs["batch_sampler"] = LengthBasedBatchSampler(
                dataset, batch_size, drop_last=True, shuffle=mode == "train"
            )
        kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
    elif train_config.batching_strategy == "packing":
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = default_data_collator
    else:
        # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
        # Any other strategy falls back to the dataset's own collator.
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = dataset.collator
    logger.info(f"Using batching strategy: {train_config.batching_strategy}")
    return kwargs
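

# Usage sketch (illustrative only; assumes `train_config` exposes the fields
# read above, and that `train_dataset` provides a `collator` when a custom
# batching strategy is configured):
#
#   from torch.utils.data import DataLoader
#   dl_kwargs = get_dataloader_kwargs(train_config, train_dataset, tokenizer, "train")
#   train_loader = DataLoader(train_dataset, num_workers=4, pin_memory=True, **dl_kwargs)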