# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch.distributed as dist
from torch.utils.data import DistributedSampler
from peft import (
    LoraConfig,
    AdaptionPromptConfig,
    PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq

# from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler

from omegaconf import OmegaConf

import logging
logger = logging.getLogger(__name__)

# def update_config(config, **kwargs):
#     if isinstance(config, (tuple, list)):
#         for c in config:
#             update_config(c, **kwargs)
#     else:
#         for k, v in kwargs.items():
#             if hasattr(config, k):
#                 setattr(config, k, v)
#             elif "." in k:
#                 # allow --some_config.some_param=True
#                 config_name, param_name = k.split(".")
#                 if type(config).__name__ == config_name:
#                     if hasattr(config, param_name):
#                         setattr(config, param_name, v)
#                     else:
#                         # In case of a specialized config we can warn the user
#                         logger.warning(f"Warning: {config_name} does not accept parameter: {k}")
#             elif isinstance(config, train_config):
#                 logger.warning(f"Warning: unknown parameter {k}")


def generate_peft_config(train_config):
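    """Instantiate a peft config (LoraConfig, AdaptionPromptConfig or PrefixTuningConfig)
    from the `peft_config` node of `train_config`, defaulting to LoRA when
    `peft_method` is not set."""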
    peft_configs = {
        "lora": LoraConfig,
        "llama_adapter": AdaptionPromptConfig,
        "prefix": PrefixTuningConfig,
    }
    config = train_config.peft_config

    params = OmegaConf.to_container(config, resolve=True)
    # (FIX:MZY): drop peft_method from params so it is not passed to the peft config constructor.
    params.pop("peft_method", None)
    peft_config = peft_configs[config.get("peft_method", "lora")](**params)

    return peft_config
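
# Usage sketch (not executed here): the returned peft config is typically passed to
# peft.get_peft_model by the caller; `model` and `train_config` below are placeholders
# for objects created elsewhere in the training pipeline.
#
#     from peft import get_peft_model
#     peft_config = generate_peft_config(train_config)
#     model = get_peft_model(model, peft_config)
#     model.print_trainable_parameters()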


def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
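    """Build the DataLoader keyword arguments (sampler/batch_sampler, batch size,
    collate_fn) for the given dataset, based on train_config.batching_strategy
    ("padding", "packing", or any other strategy handled by the dataset's own collator)
    and on whether FSDP/DDP/DeepSpeed is enabled. `mode` selects between the training
    and validation batch size."""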
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.batching_strategy == "padding":
        # Length-based batching: samples of similar length are grouped and padded per batch.
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        else:
            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode == "train")
        kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
    elif train_config.batching_strategy == "packing":
        # Packed batches use a plain (distributed) sampler and the HF default collator.
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = default_data_collator
    else:
        # Any other strategy falls back to the dataset's own collator.
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = dataset.collator
        logger.info(f"Using batching strategy: {train_config.batching_strategy}")

    return kwargs
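
# Usage sketch (not executed here): the returned kwargs are meant to be unpacked into a
# torch DataLoader; `train_config`, `dataset` and `tokenizer` are placeholders for
# objects built elsewhere in the training pipeline.
#
#     from torch.utils.data import DataLoader
#     dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, mode="train")
#     train_dataloader = DataLoader(dataset, pin_memory=True, **dl_kwargs)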