File size: 7,729 Bytes
424a0e8 961b50c 424a0e8 961b50c 424a0e8 961b50c 424a0e8 961b50c 424a0e8 961b50c 424a0e8 961b50c 424a0e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from bitsandbytes.optim import PagedAdamW32bit
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from modelscope.msdatasets import MsDataset
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)
from xtuner.dataset import process_ms_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import (msagent_react_map_fn,
template_map_fn_factory)
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE
#######################################################################
# PART 1 Settings #
#######################################################################
# Model
# Identifier of the base LLM; reused below for both the tokenizer and the
# `AutoModelForCausalLM.from_pretrained` loader.
pretrained_model_name_or_path = 'internlm/internlm-7b'
# Data
# Dataset name handed to `MsDataset.load` when the training set is declared.
data_path = 'damo/MSAgent-Bench'
# Prompt template shared by the dataset's template map function and the
# evaluation chat hook.
prompt_template = PROMPT_TEMPLATE.default
# Forwarded to `process_ms_dataset` as `max_length`.
max_length = 2048
# Forwarded to `process_ms_dataset` as `pack_to_max_length`.
pack_to_max_length = False
# Scheduler & Optimizer
batch_size = 8 # per_device
# Forwarded to the optimizer wrapper as `accumulative_counts`.
accumulative_counts = 1
# Used as `num_workers` of the training dataloader.
dataloader_num_workers = 2
# Used both as `T_max` of the LR scheduler and as `max_epochs` of `train_cfg`.
max_epochs = 3
# Optimizer class and its hyper-parameters, assembled into `optim_wrapper`.
optim_type = PagedAdamW32bit
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip
# Evaluate the generation performance during the training
# Passed to `EvaluateChatHook` as `every_n_iters`.
evaluation_freq = 500
# System prompt (written in Chinese) passed to `EvaluateChatHook`. It lists
# two tools, GoogleSearch and PythonInterpreter, and spells out a
# Thought / Action / Action Input / Response / Final Answer reply format.
# NOTE(review): the doubled braces `{{` / `}}` look like escapes for a later
# `str.format` call -- confirm against the hook before editing this literal.
# The mix of `\\n` (literal backslash-n kept in the text) and `\n` (real
# newline) is intentional in the literal as written; do not normalize it.
SYSTEM = (
'你是一个可以调用外部工具的助手,可以使用的工具包括:\n'
"{{\'GoogleSearch\': \'一个可以从谷歌搜索结果的API。\\n"
'当你需要对于一个特定问题找到简短明了的回答时,可以使用它。\\n'
"输入应该是一个搜索查询。\\n\\n\',"
"\'PythonInterpreter\': \"用来执行Python代码。代码必须是一个函数,\\n"
"函数名必须得是 \'solution\',代码对应你的思考过程。代码实例格式如下:\\n"
'```python\\n# import 依赖包\\nimport xxx\\ndef solution():'
'\\n # 初始化一些变量\\n variable_names_with_real_meaning = xxx'
'\\n # 步骤一\\n mid_variable = func(variable_names_with_real_meaning)'
'\\n # 步骤 x\\n mid_variable = func(mid_variable)\\n # 最后结果'
'\\n final_answer = func(mid_variable)\\n return final_answer'
"\\n```\\n\"}}\n"
'如果使用工具请遵循以下格式回复:\n```\n'
'Thought:思考你当前步骤需要解决什么问题,是否需要使用工具\n'
"Action:工具名称,你的工具必须从 [[\'GoogleSearch\', \'PythonInterpreter\']] 选择"
'\nAction Input:工具输入参数\n```\n工具返回按照以下格式回复:\n'
'```\nResponse:调用工具后的结果\n```'
'\n如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复\n```'
'\nThought:给出最终答案的思考过程\nFinal Answer:最终答案\n```\n开始!\n')
# Sample question passed to `EvaluateChatHook` as `evaluation_inputs`
# (the text asks what the weather in Shanghai will be like tomorrow).
evaluation_inputs = ['上海明天天气怎么样?']
#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
# Tokenizer spec: built via `AutoTokenizer.from_pretrained` for the base
# model, trusting remote code and padding on the right.
tokenizer = {
    'type': AutoTokenizer.from_pretrained,
    'pretrained_model_name_or_path': pretrained_model_name_or_path,
    'trust_remote_code': True,
    'padding_side': 'right',
}

# Model spec: a `SupervisedFinetune` wrapper around the base causal LM,
# loaded in fp16 with a 4-bit (nf4, double-quant) bitsandbytes config, plus
# a LoRA adapter configuration.
model = {
    'type': SupervisedFinetune,
    'llm': {
        'type': AutoModelForCausalLM.from_pretrained,
        'pretrained_model_name_or_path': pretrained_model_name_or_path,
        'trust_remote_code': True,
        'torch_dtype': torch.float16,
        'quantization_config': {
            'type': BitsAndBytesConfig,
            'load_in_4bit': True,
            'load_in_8bit': False,
            'llm_int8_threshold': 6.0,
            'llm_int8_has_fp16_weight': False,
            'bnb_4bit_compute_dtype': torch.float16,
            'bnb_4bit_use_double_quant': True,
            'bnb_4bit_quant_type': 'nf4',
        },
    },
    'lora': {
        'type': LoraConfig,
        'r': 64,
        'lora_alpha': 16,
        'lora_dropout': 0.1,
        'bias': 'none',
        'task_type': 'CAUSAL_LM',
    },
}
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
# Training set spec: the dataset named by `data_path` is loaded through
# `MsDataset.load` and handed to `process_ms_dataset` together with the
# tokenizer, the length/packing settings and the two map functions.
train_dataset = {
    'type': process_ms_dataset,
    'dataset': {'type': MsDataset.load, 'dataset_name': data_path},
    'tokenizer': tokenizer,
    'max_length': max_length,
    'dataset_map_fn': msagent_react_map_fn,
    'template_map_fn': {
        'type': template_map_fn_factory,
        'template': prompt_template,
    },
    'remove_unused_columns': True,
    'shuffle_before_pack': True,
    'pack_to_max_length': pack_to_max_length,
}

# Dataloader spec: per-device batch size, worker count, a shuffling
# `DefaultSampler` and the default collate function.
train_dataloader = {
    'batch_size': batch_size,
    'num_workers': dataloader_num_workers,
    'dataset': train_dataset,
    'sampler': {'type': DefaultSampler, 'shuffle': True},
    'collate_fn': {'type': default_collate_fn},
}
#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# Optimizer wrapper: mixed-precision (`AmpOptimWrapper`, float16, dynamic
# loss scale) around the optimizer selected in PART 1, with gradient
# clipping by norm and gradient accumulation.
optim_wrapper = {
    'type': AmpOptimWrapper,
    'optimizer': {
        'type': optim_type,
        'lr': lr,
        'betas': betas,
        'weight_decay': weight_decay,
    },
    'clip_grad': {'max_norm': max_norm, 'error_if_nonfinite': False},
    'accumulative_counts': accumulative_counts,
    'loss_scale': 'dynamic',
    'dtype': 'float16',
}

# Learning-rate policy: cosine annealing down to 10% of the base `lr`,
# declared per epoch and flagged for conversion to iteration-based steps.
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = {
    'type': CosineAnnealingLR,
    'eta_min': lr * 0.1,
    'by_epoch': True,
    'T_max': max_epochs,
    'convert_to_iter_based': True,
}

# Training loop settings (epoch-based; the same dict also carries the
# validation interval).
train_cfg = {
    'by_epoch': True,
    'max_epochs': max_epochs,
    'val_interval': 1,
}
#######################################################################
# PART 5 Runtime #
#######################################################################
# Optional custom hooks: one receives the tokenizer for dataset info, the
# other chats with the model on `evaluation_inputs` every `evaluation_freq`
# iterations during training.
custom_hooks = [
    {'type': DatasetInfoHook, 'tokenizer': tokenizer},
    {
        'type': EvaluateChatHook,
        'tokenizer': tokenizer,
        'every_n_iters': evaluation_freq,
        'evaluation_inputs': evaluation_inputs,
        'system': SYSTEM,
        'prompt_template': prompt_template,
    },
]

# Default hooks.
default_hooks = {
    # Record the time of every iteration.
    'timer': {'type': IterTimerHook},
    # Print the log every 10 iterations.
    'logger': {'type': LoggerHook, 'interval': 10},
    # Enable the parameter scheduler.
    'param_scheduler': {'type': ParamSchedulerHook},
    # Save a checkpoint at interval 1 (per epoch with the hook's default
    # by-epoch mode).
    'checkpoint': {'type': CheckpointHook, 'interval': 1},
    # Set the sampler seed in a distributed environment.
    'sampler_seed': {'type': DistSamplerSeedHook},
}
# Environment configuration.
env_cfg = {
    # Whether to enable cudnn benchmark.
    'cudnn_benchmark': False,
    # Multi-process parameters.
    'mp_cfg': {'mp_start_method': 'fork', 'opencv_num_threads': 0},
    # Distributed parameters.
    'dist_cfg': {'backend': 'nccl'},
}

# No visualizer is configured.
visualizer = None

# Log level.
log_level = 'INFO'

# Checkpoint to load from (none by default).
load_from = None

# Whether to resume training from the loaded checkpoint.
resume = False

# Use a random seed and leave `deterministic` disabled by default.
randomness = {'seed': None, 'deterministic': False}
|