llama / data /xtuner /docs /zh_cn /user_guides /sequence_parallel.md
kai119's picture
Upload folder using huggingface_hub
22fb4ec verified

A newer version of the Streamlit SDK is available: 1.44.0

Upgrade

序列并行:训练极长序列大模型的系统优化

XTuner 中的序列并行设计思路参考了 DeepSpeed 的工作 DeepSpeed Ulysses,并加以优化,以达到直接基于 transformers 算法库或 Huggingface Hub 上的开源模型训练 1M 以上超长序列的目标。

简介

从生成性AI到科研模型,长序列训练正在变得非常重要。

在生成性AI领域,会话式AI、长文档摘要、代码库理解和例如 Sora 这种视频生成任务都需要在空间和时间层面对长上下文进行推理。

对于科学AI来说,长序列同样至关重要,它为更好地理解结构生物学、医疗保健、气候和天气预测以及大分子模拟打开了大门。

然而,尽管序列长度的重要性不断增长,XTuner 现有的显存优化策略(如 zero 系列),却不足以解决大模型、长序列训练问题。

同时,受限于通信效率,现有的许多序列并行方法也不够高效。

另外,现有的序列并行方法普遍存在较多的代码侵入式修改,易用性和维护性都要大打折扣。同时也不满足 XTuner 基于 transformers 算法库或 Huggingface Hub 上的开源模型直接进行训练的要求。



为了解决上述长序列训练带来的问题,XTuner 采用了一种简单、易用且高效的序列并行算法。由于 Transformer 结构较为规整,除 attention 计算外,其他计算过程中 token 之间不会互相影响(即每个 token 的计算是独立的),这一条件为序列并行提供了有利条件。上图展示了序列并行的核心设计。设由 P 个 GPUs 共同计算一个长度为 N 的长序列,在 Attention 计算的第一阶段,长度为 N / P 的子序列会通过线性层投影为 Query、Key、Value。接下来, QKV Tensor 会在参与序列并行计算的多个 GPUs 之间通过高度优化的 all-to-all 通信算子汇聚,得到序列长度为 N ,但更少注意力头的子序列。注意力计算后,通过另一个 all-to-all 通信算子将其转换为长度为 N / P 的子序列,进行后续计算。

总体而言,XTuner 的序列并行算法具有以下关键特性:

  • 支持全量训练超过百万个token的序列
  • 支持百 B 级模型训练:XTuner 的序列并行不仅支持长序列训练,还可结合 zero3 显存优化策略训练大尺寸模型
  • 完全通用的序列并行 API 抽象

使用 XTuner 进行序列并行训练

Step 1 修改 config 文件

  1. 在 config 中修改 sequence_parallel_size 字段即可调整 $sequence\_parallel\_world\_size$ 。
  2. 同时若想保证与不使用序列并行的训练效果类似,需要同步增大梯度累积的数值为原来的 $sequence\_parallel\_world\_size$ 倍,因为在使用序列并行训练时, $data\_parallel\_world\_size$ 降为了原来的 $\frac{1}{sequence\_parallel\_world\_size}$。
  3. 替换 DefaultSampler 为支持序列并行的 SequenceParallelSampler。

注:需要保证所使用的 GPU 总数可以被 sequence_parallel_size 整除。

+ from xtuner.parallel.sequence import SequenceParallelSampler

- sequence_parallel_size = 1
+ sequence_parallel_size = 4  # take `sequence_parallel_size = 4`` as an example

- accumulative_counts = 1
+ accumulative_counts = 4  # accumulative_counts = accumulative_counts * sequence_parallel_size

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
train_dataloader = dict(
-   sampler=dict(type=DefaultSampler, shuffle=True),
+   sampler=dict(type=SequenceParallelSampler, seed=1024, shuffle=True),
    ...)

另外,若需要进一步拓展模型的长文本处理能力,需要进一步修改 config 中的 max_position_embeddings 字段。例如需要将模型的上下文长度拓展为 64K 时,可进行如下修改:

+ max_position_embeddings = 65536

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
model = dict(
    type=SupervisedFinetune,
+   max_position_embeddings = max_position_embeddings,
    ...)

Step 2 开始训练

需要使用 DeepSpeed 进行训练:

(DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train ${CONFIG_PATH} --deepspeed deepspeed_zero2
(SLURM) srun ${SRUN_ARGS} xtuner train ${CONFIG_PATH} --launcher slurm --deepspeed deepspeed_zero2
  • ${CONFIG_PATH} 为 Step 1 中修改得到的 config 文件路径
  • 可根据实际情况选择使用不同的 zero 策略

序列并行 API 抽象

为了提升算法的可迁移性,XTuner 中抽象出了序列并行所必须的五个 API 接口:

  • 序列并行分布式环境初始化 (init_sequence_parallel)
  • 适配序列并行的 Data Sampler (SequenceParallelSampler)
  • 数据 Pad (pad_for_sequence_parallel)
  • 数据切分 (split_for_sequence_parallel)
  • 适配序列并行的 Attention (dispatch_modules)
  • reduce loss 以正确打印训练损失 (reduce_sequence_parallel_loss)

序列并行分布式环境初始化

由于序列并行算法会将长序列切分为 $sequence\_parallel\_world\_size$ 块,并将每个子序列分发给对应的 GPU 独立进行计算。因此需要在训练开始前初始化序列并行分布式环境,以指定哪几块 GPU 共同负责一个长序列输入的计算。

一个 $sequence\_parallel\_world\_size = 4$ 的示例如下:

# We have to initialize the distributed training environment first.
# Here is an example when training on slurm scheduler
# from xtuner.parallel.sequence import init_dist
# init_dist('slurm', 'nccl', init_backend='deepspeed')
from xtuner.parallel.sequence import init_sequence_parallel
sequence_parallel_world_size = 4
init_sequence_parallel(sequence_parallel_world_size)

上述过程在 xtuner/engine/_strategy/deepspeed.py 中实现。

Data Sampler 适配序列并行

在使用序列并行后,Dataloader 的采样策略需要进一步调整。例如当 $sequence\_parallel\_world\_size = 4$ 时,4 块 GPU 从 Dataloader 拿到的数据需要是完全一样的。

在构建 Dataloader 时搭配 XTuner 中提供的 SequenceParallelSampler 使用即可:

from xtuner.parallel.sequence import SequenceParallelSampler
dataloader = DataLoader(
    train_dataset, sampler=SequenceParallelSampler(train_dataset),
    **other_dataloader_params)

数据 Pad

由于每条训练数据的长度可能不尽相同,我们需要将数据进行 Pad 以使得序列长度可以被 $sequence\_parallel\_world\_size$ 整除,这样一条长数据才能被均等地分发给不同的 GPU 上。

训练过程中需要被 Pad 的 Tensor 往往有 input_ids, labels, position_ids, attention_mask 四个,pad 的过程可以通过以下方式实现:

from xtuner.parallel.sequence import pad_for_sequence_parallel

input_ids = pad_for_sequence_parallel(input_ids, padding_value=0)
labels = pad_for_sequence_parallel(labels, padding_value=-100)
position_ids = pad_for_sequence_parallel(position_ids, padding_value=0)
attention_mask = pad_for_sequence_parallel(attention_mask, padding_value=0)

以上过程在 xtuner/dataset/collate_fns/default_collate_fn.py 中实现。

数据切分

在传入给 Transformer 模型前,我们需要对长序列均等切分:

from xtuner.parallel.sequence import split_for_sequence_parallel
# attention mask should not be split
# `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
input_ids = split_for_sequence_parallel(input_ids, dim=1)
labels = split_for_sequence_parallel(labels, dim=1)
position_ids = split_for_sequence_parallel(position_ids, dim=1)

以上过程在 xtuner/model/sft.py 中实现。

Attention 适配序列并行

在 Attention 的计算过程中,序列中的不同 token 是不能独立运算的,但不同的 attention head 之间的计算却是独立的。因此,如第一节所述,需要在计算 Attention 前后(即 qkv_proj 后和 o_proj 前)分别插入一个 all-to-all 操作。

XTuner 提供了 dispatch_modules 接口以支持修改模型 Attention 的计算方式:

from xtuner.model.modules import dispatch_modules
model: AutoModelForCausalLM
dispatch_modules(model)

上述过程在 xtuner/model/sft.py 中实现。

Reduce Loss 以正确打印训练损失

这个 API 对于保证训练的正确性不是必须的,但对于观测模型训练状态,打印训练 loss 是非常有用的。

from xtuner.parallel.sequence import reduce_sequence_parallel_loss
outputs = llm(input_ids=input_ids, labels=labels, **kwargs)
num_tokens_per_rank = (labels != -100).sum()
# Suppose sequence parallel world size equals to 4,
# losses on rank0, rank1, rank2, rank3 are different.
loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens_per_rank)
# After loss reduction, losses on rank0, rank1, rank2, rank3 are the same.

上述过程在 xtuner/model/sft.py 中实现。