(7)LLM分布式训练面面观
开篇
大家好,我是小A。前面几篇我们感受了Transformer在LLM、多模态和AIGC等算法领域的应用,从这一篇开始我们把目光转向工程部分。本文将围绕LLM里面分布式训练,回答下面4个问题
- 分布式通信原语有哪些?
- 3D并行都是怎么做的?
- ZeRO是如何降低数据并行的显存占用的?
- Alpa是如何用自动化搜索方式逼近人工设计的最优方案的?
PS: 长文预警,本篇约2w字,关注&收藏后电脑上阅读体验更加哦~ (^_^)
分布式通信原语
LLM分布式训练需要用到多机多卡,通信是个最基础的问题。因此在正式介绍并行算法之前,我们先了解一下常用分布式通信源语(primitives)有哪些。
分布式通信一般有两种
- 点对点通信(Point-to-point Communication, P2P): 两个节点间通信
- 集合通信(Collective Communication, CC): 在一组节点内通信
使用最为广泛的分布式通信库有MPI和NCCL,在工程实践中通常CPU上用MPI,GPU上用NCCL
MPI
MPI(Message Passing Interface)全称是是消息传递接口,定义了一套接口,比较流行的开源实现是OpenMPI。常用于计算集群,超算上。针对CPU之间的通信做了很多优化。
目前MPI广泛用于节点的管理,但是GPU之间的通信优化得不好,MPI没有针对性的做优化
编译带mpi的c文件的时候,需要用 mpicxx
编译器
NCCL
NCCL(NVIDIA Collective Communication Library) 是NV专门针对自家的计算卡和网络的集合通信库,完成单机多卡和多机多卡的高速互联通信
这里repo包含了简单易懂的所有NCCL使用示例,感兴趣的可以移步学习。注意在编译nccl相关的cu文件时候,需要加上 -lnccl
的库
点对点通信原语
点对点的操作比较简单,只有发送和接收两个操作
- 发送的时候需要指定发送buffer和发送给哪个peer对象编号
- 接收的时候需要指定接收buffer和数据来自哪个peer对象编号
NCCL和MPI的接口如下所示
集合通信常用原语
集合通信是在一组节点内相互通信,在MPI中大致分成三种
- 一对多。例如Broadcast和Scatter等。
- 多对一。例如Reduce和Gather等
- 多对多。例如AllReduce、AllGather、ReduceScatter和AllToAll
值得注意的是在NCCL中,原语种类比MPI要少,只包含了5种,分别为Broadcast、Reduce、AllGather、ReduceScatter和AllReduce
Broadcast和Scatter
Broadcast是把同一份1份数据广播给N个节点,Scatter是把N份数据发给N个节点。他们的相同点都是从一个节点发给多个节点,区别就是Broadcast发送的是同一份数据,Scatter发送的是不同的数据
Reduce和Gather
- Reduce是搜集N个节点的数据,并且对这些数做reduce操作,例如SUM/MIN/MAX/PROD等
- Gather是搜集N个节点的数据,得到长度为N的数组
AllGather
每个节点各有1分独特数据,相互交换让N个节点有N份数据
- Naive实现: 通过gather+broadcast来实现
- SOTA实现: 直接并行发送
ReduceScatter
有N个节点,每个节点有长度为N的数据,首先做reduce得到长度为N的数据,然后scatter分发给N个节点,每个节点得到长度为1的数据
- Naive实现: 通过Reduce+Scatter实现
- SOTA实现: 直接并行发送
AllReduce
搜集N个节点数据,得到数组做reduce之后,再发给N个节点。
- Naive实现:通过Reduce+Broadcast来实现
- SOTA实现:ReduceScatter + AllGather
AllToAll
跟AllGather类似,但是AllGather是每个节点拿到的结果最终是一样的,而AllToAll的每个节点拿到的结果是不一样的
数据并行(DP)
数据并行(Data Parallel)是使用最为广泛的并行方法,加速性能非常好,原因是各个数据切片可以做到完全解耦,只需要在最后每个mini-batch结束的时候做一下梯度的all-reduce既可。
数据并行可以分为中心化方式的和无中心化方式的,对应于pytorch里面的DataParallel和DistributedDataParallel
这两种方式最大的区别是gradient和reduce计算过程
- DataParallel是要在forward之后把所有输出gather到0号卡上,计算完loss之后再scatter到各个设备上,然后做backward独立计算gradient,最后搜集gradient到0号卡。因此需要在forward和backward间插入一次通信
- DistributedDataParallel是每张卡独立的做forward和backward,然后对各卡的gradient做all-reduce。因此forward和backward间无需通信
流水线并行(PP)
当模型参数量太大,一张卡不能完全放下的情况下,就必须对模型进行切分了,流水线并行(Pipeline Parallel)就是一种切分方法。具体来说
- 沿着模型的拓扑序,切分成p 段,每一段为一个stage,因此可以形成逻辑上相互串联的p 个stage
- 将大小为N 的mini-batch进一步切分为m 个大小为M 的micro-batch,因此N = m \cdot M
- 这些micro-batch依次进入上述p 个stage
我们首先研究一下流水线并行切分之后,每个stage的耗时跟原来相比会有什么变化,如下所示
- 假设1:算力为1 的节点,处理完整模型的1个micro-batch,前向和反向耗时分别是t_f 和t_b
- 推论1:算力为1/p 的节点,处理完整模型的1个micro-batch,前向和反向耗时分别是p \cdot t_f 和p \cdot t_b
- 推论2:算力为1 的节点,处理1/p 模型的1个micro-batch,前向和反向耗时分别是t_f / p 和t_b/p
- 推论3:算力为1/p 的节点,处理1/p 模型的1个micro-batch,前向和反向耗时分别是t_f 和t_b
因此根据推论3可知,所有算力处理整个模型的耗时,跟1/p 的算力处理1/p 段模型的耗时,是一致的。
下面开始讨论各种策略下的耗时情况。
理论上界
我们先考虑理想最优情况。此时显存无限大,不需要pipeline并行,一把梭直接对mini-batch的样本做前向和反向,耗时正比于样本数量。由假设1不难算出b_\text{best} 耗时为
t_{\text{best}} = m \cdot (t_f+t_b)
朴素串行
下面考虑最朴素的串行方式,类似于CPU的单周期处理器时代。每个micro-batch串行逐个做前向和反向,如下所示
这里蓝色是前向过程,绿色是反向过程,粗略认为反向过程是前向的2倍耗时。此时由推论3知道一个micro-batch的耗时是p\cdot(t_f+t_b) ,那么1个mini-batch的耗时t_\text{sequential} 是
t_\text{sequential}=mp\cdot(t_f+t_b)
可见耗时是理论上界的$p$倍,其中有大量的计算资源在闲置空载,硬件利率用很低
Gpipe
接着考虑Gpipe的朴素流水线并行,类似于CPU的流水线处理器时代。所有micro-batch用p 级流水线并行做前向和反向,注意需要等所有micro-batch都计算完才能执行反向过程
分析发现此时总耗时t_\text{naive} 为
t_\text{naive}=p\cdot (t_f+t_b) + (m-1)\cdot(t_f+t_b)=(m+p-1)\cdot(t_f+t_b)
bubble分析
盗图一张,不难发现t\text{naive} 比t\text{best} 会多出不少耗时,这部分在上图中也能看出来,就是3个部分的气泡空腔bubble。bubble越多,代表硬件闲置越多,因此一般定义bubble ratio来衡量流水线算法对硬件的浪费程度,值越小说明流水线效率越高。
bubble ratio的计算有两种方式,物理含义就是空腔的面积占整体面积的比例
计算方式1:用t\text{naive} 和t\text{best} 来计算
\begin{equation} \text{bubble ratio} = \frac{t_\text{naive}-t_{\text{best}}}{t_\text{naive}} = \frac{(m+p-1)(t_f+t_b) - m(t_f+t_b)}{(m+p-1)(t_f+t_b)} = \frac{p-1}{m+p-1} \end{equation}
计算方式2:看图计算空腔比例,等价于1减去满载比例
\text{bubble ratio} = 1 - \frac{2\cdot pm}{p\cdot2(m+p-1)}=\frac{p-1}{m+p-1}
由上述公式可以看出,m 相比于p 越大,气泡空腔的面积就越小,经验上认为m \ge 4 p 的时候就基本可以忽略了。
Re-materialization
在Gpipe里面,为了减少显存的占用,可以使用checkpoint技术,也叫re-materialization技术。即只保留每个stage的输入activation,backward需要各层的activation的时候从stage开头重新forward一遍,因此一共forward两遍
PipeDream (Non-Interleaved 1F1B)
为了进一步压缩气泡空腔,可以继续增加micro-batch的个数m ,但是我们发现Gpipe的里面,所有的micro-batch要forward攒齐了才做backward,这里会缓存大量的中间activation结果。
因此类似于数据并行里面从DataParallel变成DistributedDataParallel,我们可以解耦同一个mini-batch的不同micro-batch,让他们独立的做forward和backward。只是要求在最后一个micro-batch完成backward之后对各个micro-batch的gradient取个平均后执行optimization_step。
PipeDream这样做的好处,是可以进一步降低显存的使用,从而把\frac{p-1}{m+p-1} 中的m 进一步放大,降低了气泡空腔的比例。
但值得注意的是,在PipeDream的原始论文里面,并行的颗粒度是mini-batch之间,而不是micro-batch之间。因此对于每个device而言,会存在多个mini-batch交替进行异forward-backward和梯度更新,因此需要存储多个权重的版本。
对PipeDream的显存进一步优化的工作有PipeDream-2BW和PipeMare,感兴趣的可自行
Interleaved 1F1B
Interleaved 1F1B是源自Megatron-LM中的方法,希望进一步提高m 来降低气泡空腔比率。思路是用前面的推论2,即对于每个device,固定算力,将micro-batch和流水线切分更细,使得每个device负责的流水线级数翻倍。这样跟原来相比,每个device上stage的耗时成倍减少的同时,stage数量也成倍增多。
举个具体的例子,如下所示
- 有4个micro-batch,每个micro-batch由blue和gray两部分组成,例如1由blue1(简称b1)和gray1(简称g1)组成
- 有4个device(简称d1-d4)
那么原来Non-Interleaved 1F1B中数据和device的对应关系为
\text{(b1,d1)-(g1,d1)-(b2,d2)-(g2,d2)-(b3,d3)-(g3,d3)-(b4,d4)-(g4,d4)}
新的Interleaved 1F1B的对应关系为
\text{(b1,d1)-(g1,d2)-(b2,d3)-(g2,d4)-(b3,d1)-(g3,d2)-(b4,d3)-(g4,d4)}
直观上想这样一增一减的折腾貌似没啥用,但仔细分析发现切分更细之后,原来的气泡空闲变得有事可干了,如上图红色部分。
这样做当然也不是免费的,代价就是更多的通信量,但最终端到端效果看是利大于弊
张量并行(TP)
当模型参数继续增大,除了用流水线并行之外,还可以用张量并行(Tensor Parallel)来缓解,就是以前常说的模型并行(Model Parallel)。
这里拿最基础的矩阵乘举例。对于激活值矩阵A和权重矩阵B相乘,考虑对B做列切分或者行切分,会有两种方案
- 方案1:按列切分。A不变,这样每一列独立计算,最后结果concat在一起
- 方案2:按行切分。A对应列切分,这样两部分独立计算,最后结果sum在一起
Megatron
在Megatron里面,作者详细分析了Transformer结构中每个部件的具体切分方案,如下所示
FFN-expansion矩阵乘
- FFN第一个expansion矩阵乘之后会接element-wise的GELU,所以
- 如果是方案2的话那么需要先做同步AllReduce然后再做GELU
- 如果是方案1的话,可以分开分别做GELU,推迟到很后面才做AllReduce
- 所以是方案1更好
FFN-contraction矩阵乘
- 经过FFN-expansion阶段后的输出此时新激活值的shape为(BL, 4D) ,并且使用了方案1对列4D做了切分
- 因此FFN-contraction阶段只需要对FFN-contraction的权重矩阵做行切分,就能满足上述方案2
- 最后做AllReduce Sum既可
f和g
这里的f和g是对偶的过程
- 在前馈过程中,f是identity,g是AllReduce
- 在反馈过程中,g是identity,f是AllReduce
MHA
- 每个head天然已经在第0维切开了,完全独立计算
- 跟FNN类似选择方案1,然后做完input proj和attention之后,对output_proj的权重矩阵做行切分
- 最后做AllReduce Sum既可
输入embedding
- 假设word embedding矩阵是E_{H \times v} ,其中H 是feature维度,v 是单词本大小
- 一般是E_{H\times v} X ,因此可以用方案2,即E=[E_1, E_2] ,X=[X_1;X_2]
- 对得到结果的AllReduce Sum即可
输出embedding
- 输出的时候每个token要跟单词本的向量做点积过softmax返回最大概率的词,即
\text{GEMM}[Y_1, Y_2] = [XE_1,XE_2]
- 其中X 是(BL, H) 尺寸, E_1 是(H, v_1) 尺寸,E_2 是(H, v_2) 尺寸
- 由于接下来要计算softmax,此时naive方法是做通信整个(BL, v) 的矩阵,但是由于v 很大,通信量很大
- 优化的策略就是各自求出softmax的指数分母的和,得到长度为BL 的和向量,对和向量做AllReduce,然后再计算softmax的值,这样通信量就从BLv 降低到了
通信量问题
宏观上看,每个Transformer结构内有两个带参数的模块,MHA和FFN。在张量并行模式下,这两个模块内的参数均匀分布在多个node上,模块之间需要AllReduce串行等待,每次forward和backward一共需要4次AllReduce。并且计算和通信没法并行。
因此相比于流水线并行中每个stage间才通信不同,张量并行每个Layer内都得通信,因此张量并行的通信量远大于流水线并行,一般在高速互联的同一个island内才使用张量并行
ZeRO
ZeRO的出发点是希望优化数据并行里显存占用。因为在数据并行里面,每个device上都有完整的权重信息,梯度信息和优化器状态信息,这个其实是比较冗余的。
GPT2显存占用分析
在具体介绍ZeRO的显存优化方案之前,我们先仔细分析一下如果用naive的训练策略,GPT2模型是如何在32G的单卡上,以32的batch size大小训练起来的。
已知基本的参数如下
- GPT2模型的参数量\Psi=1.5B ,序列长度L=1024 ,batch size大小是B=32 ,feature维度是D=768 ,一共12 个transformer block
- GPU的显存大小是32G
显存占用大致可以分成4部分,参数+梯度+优化器+激活值,大小为
2\Psi +2\Psi + K\Psi +7.2G= (2+2+14) \Psi + 7.2G=16 \Psi+7.2G=31.2G
详细数值拆解如下
- 参数parameters(fp16): 2\Psi=3GB
- 梯度gradients (fp16): 2\Psi=3GB
- 优化器optimizer states: 一般是K\Psi
,不同优化方式的系数K
不一样,例如adam的话是12\Psi
- parameters(fp32): 4\Psi=6GB
- momentum(fp32): 4\Psi=6GB
- variance(fp32): 4\Psi=6GB
- 激活值等residual states (约为7.2GB
)
- activation(fp16): 1个transformer block的激活值总量至少为12BLD ,那么12个为144BLD=3.6G 个激活值,大概7.2G显存
- temporary buffers: 包括了all_reduce的空间等
- fragmented memory: 内存碎片
显存放不下怎么办
假设模型继续增大,一张卡就放不下了,自然需要从权重和激活值两大方面砍预算,常见思路如下
- Paramater Server(PS),参数服务器思想,权重在整个集群集中存储一份,其他worker按需所取。
- Pipeline Parallelism(PP),按网络深度延伸方向切分网络成若干段,每一段的权重放在不同卡上,典型代表GPipe
- Tensor Parallelism(TP),把矩阵乘做矩阵分解,从而把权重矩阵均匀拆成若干份放在不同卡上,典型代表Megatron-LM
- re-materize技术。可以把显存降低到1/\sqrt{N} ,但是引入了额外一次forward,占用大概原始的33%的计算
- CPU-offloading。把激活值直接放到cpu上,BP的时候再加载回来
ZeRo借鉴了上述方法中的部分思路,分别就权重和激活值两个方面入手优化显存,分别称为ZeRO-DP和ZeRO-R
ZeRO-DP显存优化
ZeRO-DP的显存优化有三个层级,ZeRO-1,ZeRO-2和ZeRO-3
下面分析各种策略下的通信量和显存优化情况
策略1: naive DP的通信量
- 沿用前面数据并行的DistributedDataParallel,只在gradient更新的时候AllReduce一次,然后各卡上并行跑optimizer步骤
- 对gradient做AllReduce,使用ReduceScatter和AllGather方式,于是通信量是\Psi + \Psi = 2\Psi
- 显存方面没有任何优化,参数、梯度和优化器状态在各卡都有一份,显存占用16\Psi=120GB
策略2: Zero-1方案P_{os}
- P_{os} 对优化器状态optimizer state做partition
- 每张卡gradient算完之后用ReduceScatter发送到对应的卡上即可,通信量为\Psi
- 在各卡做完optimizer state更新权重之后,把权重做AllGather,通信量为\Psi
- 因此总的通信量为2\Psi ,跟naive DP一致
- 显存方面,参数和梯度跟naive DP一样,优化器状态降到了1/N_d=1/64 ,因此显存占用为4\Psi+12/64\Psi=4.1875\Psi=31.4GB ,约为策略1的1/4
策略3: ZeRo-2方案P_{os+g}
- P{os+g} P{os+g} 对optimizer state和gradient做partition,情况跟P{os} 类似,通信量也是\Psi + \Psi = 2\Psi ,因此跟naive DP一致
- 显存方面,参数跟naive DP一样,梯度和优化器状态都降到了1/N_d=1/64 ,因此显存占用为2\Psi+14/64\Psi=2.22\Psi=16.6GB ,约为naive DP的1/8
策略4: ZeRo-3方案P_{os+g+p}
- P_{os+g+p} 对optimizer state,gradient和weight都做了partition
- 前馈过程中,weight需要做一次AllGather才能算出各层的activation,通信量\Psi
- 反馈过程中,weight也要做一次AllGather才能算出各层的gradient,通信量\Psi
- 各层gradient需要ReduceScatter发送到对应的卡上,通信量\Psi
- 因此总的通信量为3\Psi ,为naive DP的1.5倍,增加50%通信量
- 显存方面,参数、梯度和优化器状态都降到1/N_d=1/64 ,因此显存占用为16/64\Psi=1.9GB ,约为naive DP的1/32
不同尺寸模型在不同DP并行度下,不同ZeRO策略的显存占用如下所示
ZeRO-R显存优化
R这里指Residual States,就是除了参数、梯度和优化器状态的其他显存占用,包括以下几方面
- activation。用已经切块之后的activation checkpointing技术,就是re-materialization技术
- temporary buffers。用一个固定大小的buffer
- fragmented memory。on-the-fly策略调配显存位置,消除碎片,具体来说
- 前馈过程: 保留的activation(long live)和删掉的activation(short live)交替出现而留下的碎片
- 反馈过程: 相对于weight的梯度(long live)和计算weight梯度需要的输入(short live)
DP/PP/TP小结
- DP的优势是计算和通信效率都很友好,但是权重的显存不友好,每张卡都有一份。因此ZeRO对此进行大刀阔斧优化,借鉴了Parameter Server的思路,ZeRO-3让每张卡只维护一部分的网络权重、梯度和优化器状态
- PP的问题是要求mini-batch里面batch size足够大才能掩盖住流水线带来的overhead。batch size如果过大,会增大激活显存的占用
- TP的优势是权重显存非常友好,没有冗余。但是计算和通信效率不友好,通信量要求很大,在超出了一个island的时候性能下降很快。
Alpa
前面说的方法都是人工设计的并行策略,而Alpa是一种自动化搜索并行策略的方法,很大程度上借鉴了AI编译器的思想。该文的一作就是TVM里面自动化搜索方法Ansor的一作。
Alpa主要针对计算图和硬件集群之间的分配问题,提出了Inter-op和Intra-op两阶段的自动并行策略搜索算法。这个算法整体是两层循环,先感性理解一下,如下图所示
- 外层循环(Inter-op):将计算图op序列切分成不同的stage,同时将硬件集群也划分为不同的mesh网格,形成若干个stage-mesh对
- 内层循环(Intra-op):对于每一种个stage-mesh对,用整数规划方法估计出计算和通信cost
- 运行编排(Runtime Orchestration):插入必要的within-mesh和cross-mesh的通信收发等节点,编译出完整可以跑的运行文件,得到完整端到端速度
问题定义
个人认为Alpa最大的贡献是对计算图和硬件集群分配问题做了数学建模,并且使用合理的简化和对应数学手段求解这个数学问题。下面我们先简单介绍一下计算图和硬件集群分配问题。
- 假设一个计算图有K 个结点,并且是个DAG,拓扑序依次为o_1 \rightarrow o_2 \rightarrow \cdots \rightarrow o_K
- 假设一个计算集群有 N \times M
个计算设备。一般来说每一行是一个高速互联的island,不同行之间的链接会稍微慢一点,例如
- GPU集群里面就是单机8卡,M=8
- TPU集群里面就是最大512个TPU组成的Pod,M=512
分布式计算本质上就是找到一个分配策略把K 个结点首先切分为S 个阶段,每个阶段包含计算图算子集合s_i ;同时把计算集群也切分为S 块,每一块大小为n_i\times m_i ,记为\text{Mesh}(n_i, m_i) 。然后把s_i 分配到\text{Mesh}(n_i, m_i) 进行计算。如下所示
这里注意\text{Mesh}(n_i, m_i) 被规定为n_i \times m_i 的矩形块,形状只有两种
- 一维情况: (1,1) 或者(1, 2) 或者(1, 4) 或者(1, 2^b) ,即高为1,宽为2的整数次幂
- 二维情况: (1, M) 或者(2,M) 或者(a, M) ,即高任意正整数,宽为固定M
可以证明对于任意形状的N \times M 计算集群,总能进一步切分为一系列{(n_1,m_1), \cdots, (n_S, m_S)} 计算网格,这些子网格是可以无重无漏地完整覆盖原始计算机群的。
因此计算图和硬件集群分配问题的求解过程,可以
- 先定义内层循环Intra-op的耗时接口为t_\text{intra}(s_i, \text{Mesh}(n_i, m_i))
- 然后外层循环Inter-op通过不停地枚举切分方式和调用Intra-op的耗时接口
- 最终找到一个端到端耗时最短的分配策略
下面我们就先详细了解一下内层循环Intra-op的细节,然后看外层循环Inter-op是如何做切分搜索的
Intra-op
Intra-op模块的输入是计算图s_i 和计算资源\text{Mesh}(n_i, m_i) ,输出的是最优分配策略和对应的实际耗时。由于我们使用的是SPMD的计算方式,因此\text{Mesh}(n_i, m_i) 内的所有device是对等peer关系,对每个op的执行都是一样的。因此我们可以转变为贪心算法,只需要依次确定s_i 里面的每个op的输入tensor是如何被切分发送到这n_i \times m_i 个device上进行计算,那整个s_i 和\text{Mesh}(n_i, m_i) 的分配策略就确定了。
这里岔个话题简单介绍分布式里常见的几个名词,SIMD、SIMT和SPMD
- SIMD(Single Instruction Multiple Data),就是我们常说的向量化指令集,一个指令可以处理多个数据。例如X86的SSE/AVX和ARM的NEO等
- SIMT(Single Instruction Multiple Data),一个线程块的多个thread线程执行同样的程序,最常见的就是CUDA编程了
- SPMD(Single Program Multiple Device),多个device上执行相同的程序,完全对等peer关系,然后用过collective communication的方式做信息交互
回到Intra-op的逐op切分过程。对于每个op,我们只需要关注输入tensor的切分选择既可,一般来说三选一,分别为Row-partitioned、Column-Partitioned和Replicated
拿matmul举例
- 数据并行情况下。输入是Row-partitioned,权重是Replicated
- 张量并行情况下。输入是Replicated,权重是Row-partitioned或者Column-partitioned
为了方便地描述tensor到device的切分方式,我们引入Sharding spec概念。类比于op的layout,例如conv2d有NHWC和NCHW两种layout,当前一个op输出layout不满足当前op输入layout要求的时候,我们做个layout转换既可。
举例2维tensor被切分到2x2的device mesh的情况,alpa定义了如下的9种Sharding spec
符号含义如下
- 字母个数跟tensor维度保持一致,字母含义只有两种,R代表Replicated,S代表Split。因此RR就是Replicated,SR就是按行切分,RS就是按列切分,SS就是行列都切分
- 为了描述device的映射,用上标数字代表被切分后映射到device mesh的哪个维度。因此只有S字母有上标数字,特殊的S^{01}R 代表把tensor沿着行均匀分成4分,同理得到RS^{01}
- 表格种每一列是原始tensor A被分配到各个device上的具体情况
有了上述的Sharding Spec,那么当出现不匹配的时候,引入Spec转换既可,即Resharding。但是注意这里跟layout 转换出现内存重排一样,Spec转换可能出现额外通信开销,如下所示
由上可知,每个op的每个tensor都有若干种Sharding spec,如果spec不匹配还会出现Resharding的额外通信开销。这个问题可以归约为整数规划问题,目标函数定义如下
\min_{s} {\sum_{v \in V}s_v^T(c_v+d_v)+\sum_{(v,u)\in E}s_v^TR_{vu}s_u }
其中
- V 可以看成所有待确定spec的tensor,v 代表其中任意一个tensor
- s_v 代表这个tensor所有的spec选择的one-hot向量,对应的c_v 和d_v 就是相应spec选择的通信代价和计算代价
- R_{vu} 就是Resharding代价矩阵,衡量相邻v 和u 的通信代价
接下来就是如何求解上述目标函数,本质上是个数学问题。这里有若干技巧
- 把二次项的矩阵R_{vu} 做flatten变成1维的,然后转化成跟第一项类似的线性one-hot选择,于是整数规划问题转化为了整数线性规划问题(ILP),有成熟的python package可以调用
- 仿照前人做法,将计算代价d_v
设置为0,原因是
- 对于计算密集型算子,不允许用replicated方式,因此SPMD下任何一种切分计算量都是一致的,因此可以置为0
- 对于非计算密集型算子,计算量本身可以忽略不计,因此也可以置为0
- 通信代价用通信量除以带宽近似计算
至此给定计算子图s_i 和计算网格\text{Mesh}(n_i, m_i) ,对应的最优耗时t_\text{intra}^{*}(s_i, \text{Mesh}(n_i, m_i)) 可以通过上述ILP过程求解得到
Inter-op
如何将原计算图切分为S 个子图s_i ,以及如何将\text{Mesh}(N, M) 的计算网格切分成S 个子计算网格\text{Mesh}(n_i, m_i) ,这是接下来Inter-op需要解决的问题。
切分计算图其实就是流水线并行的过程,我们考虑Gpipe或者synchronous 1F1B的流水线方式。假设网络被分成了4段,每一段分别放在不同的算力节点上,逻辑上串联形成流水线并行,如下所示
由于网络每一段的计算吞吐不完全一样,因此有快有慢。假设各段前向过程耗时为t_1 \sim t_4 ,那么容易知道整个mini-batch的总前向耗时为
t_\text{pipeline}=t_1+t_2+t_3+t_4 + (m-1)\cdot \max \{t_1, t_2, t_3, t_4\}
其中\max{{t_i}} 其实就是寻找最慢的stage。
泛化到一般情况,可以知道当切分为S 个子图后,整体端到端的耗时为
T^* = \min_{s_1,\cdots,s_S \atop (n_1,m_1),\cdots,(n_S,m_S)} \left\{\sum_{i=1}^{S}{t_i} + (B-1) \cdot \max_{1\le j \le S}\{t_j\}\right\}
其中t_i 是子图s_i 的执行时间,B是流水线并行中micro-batch的大小,在上图例子中B=8
上式求解非常困难,因为切分方式未知,并且还有个求最大值过程。仔细分析,如果我们假定所有stage最大值是t_\text{max} ,即第二项是已知的,那么只需重点考察第一个求和项了。因此可以把原问题转换为带最大值约束的最优化问题,如下所示
\begin{equation} T^* = \min_{\substack{s_1,\cdots,s_S \\ (n_1,m_1),\cdots,(n_S,m_S)}} \left\{\sum_{i=1}^{S}{t_i} + (B-1)\cdot t_\text{max} \right\} , \ \text{subject to} \ t_i \le t_\text{max} \end{equation}
为了进一步考虑子图和计算资源的划分,我们用DP动态规划求解,并且定义3维状态方程F(s,k,d;t_{max})
- 含义为把从o_k 到o_K 的计算图切分为s 个子图,并且放在d 个device上的最低耗时分配方案。
- 注意原始原图共K 个op,拓扑序为o_1 \rightarrow o_2 \rightarrow \cdots \rightarrow o_K
- 要求约束每个子图最大耗时为t_\text{max}
由F(s,k,d;t_{max}) 可以得到状态转移方程为
\begin{split}F(s,k,d;t_{max})= \min_{ k \le i \le K} \left\{t_{\text{intra}}((o_k,\cdots,o_i),\text{Mesh}(n_s,m_s),s) +F(s-1,i+1,d-n_s \cdot m_s;t_{max}) \atop| t_{intra} ((o_k,\cdots,o_i),\text{Mesh}(n_s,m_s),s) \le t_{max}\right\}\end{split}
方程比较复杂,含义如下
- 考虑把o_k \rightarrow \cdots \rightarrow o_K 从中间o_i 切开,前面o_k 到o_i 部分形成新的stage,并且放置到\text{Mesh}(n_s, m_s) ,这个耗时调用前面的Intra-op
- 剩下的子问题F(s-1,i+1,d-n_s\cdot m_s;t_{max})
- 注意这里的\text{Mesh}(n_s, m_s) 会穷举所有可能的值,只要n_s \cdot m_s <d
因此我们不难发现
\begin{equation}T^*(t_{max})=\min_{s}\{F(s,0,N\cdot M;t_{max})\}+(B-1)\cdot t_{max}\end{equation}
上述三维的DP过程复杂度很高,有些tricks可以降低复杂度
- t\text{max} 在从小到大的穷举过程中,如果B\cdot t\text{max} 已经大于当前的T^* 了,那么可以提前剪枝停止算法了
- 提前对原计算图中的非密集型算子(如ReLU)等做算子融合,可以极大减少搜索空间
Runtime Orchestration
为了的得到最终可运行的二进制文件,还需要做一些后处理,例如within-mesh和cross-mesh的通信算子。这里corss-mesh是指不同stage之间的通信,由于往往他们之间的通信带宽不高,因此需要用一定的优化策略,如下所示
红色边是跨mesh的较慢通信线路,绿色是mesh内的快速通信线路
对于mesh内的通信
- 如果前后两个mesh是相等尺寸的话,用(a)的scatter-gather方式
- 如果是不等尺寸的话,用(c)的all-gather方案
实验结果
alpa比Megatron-LM、Inter-op Only和Intra-op only策略都要好一点,但是跟最优秀的人工设计相比还是略差一点
最后还可视化了一下Wide-ResNet在16GPU上的搜索结果
写在最后
最后小结一下,本文主要有4部分内容
- 分布式通信原语包括了点对点通信和集合通信的方法。其中集合通信包括了一对多的Broadcast和Scatter,多对一的Reduce和Gather,多对多的AllReduce、AllGather、ReduceScatter和AllToAll
- 3D并行包括了数据并行(DP)、流水线并行(PP)和模型并行(TP)
- DP的优势是计算和通信效率都很友好,但是权重的显存不友好,每张卡都有一份
- PP的问题是要求mini-batch里面batch size足够大才能掩盖住流水线带来的overhead。batch size如果过大,会增大激活显存的占用
- TP的优势是权重显存非常友好,没有冗余。但是计算和通信效率不友好,通信量要求很大,在超出了一个island的时候性能下降很快。
- ZeRO针对数据并行显存占用大的问题,借鉴了Parameter Server的思路,提出了ZeRO-1,ZeRO-2和ZeRO-3的优化。其中ZeRO-2让每张卡只维护一部分的梯度和优化器状态,显存占用减少到原来的$1/8$,通信带宽保持不变
- Alpa鉴了AI编译器的思路,对3D并行进行建模,用自动化搜索的方式得到了仅次于手工最优的并行策略
PS:由于笔者小A并没有亲手撸过上述内容的所有细节,大部分是通过研究代码和精读优秀文章的方式bottom-up总结而来,本质上是个拾人牙慧的知识搬运工,所以终究是纸上谈兵。因此希望各方有实际经验的大佬猛锤,思维碰撞才生火花,真理越辩越明。
如果想了解transformer在NLP/多模态/AIGC的算法知识,分布式训练的知识,以及如何在TVM上做PTQ量化和部署,可以关注我aaronxic哟~
系列文章导览
[Transformer 101系列] Perplexity指标究竟是什么?
[Transformer 101系列] ChatBot是怎么炼成的?
[Transformer 101系列] AIGC组成原理(上)
参考资料
大语言模型(LLM)分布式训练框架总结_PaperWeekly的博客-CSDN博客
NLP(十二):DeepSpeed Inference 在 LLM 推理上的优化探究
DeepSpeed ZeRO++:降低4倍网络通信,显著提高大模型及类ChatGPT模型训练效率
大语言模型(LLM)分布式训练框架总结_PaperWeekly的博客-CSDN博客
Collective Operations — NCCL 2.18.1 documentation
Distributed data parallel training using Pytorch on AWS
Pipeline-Parallelism: Distributed Training via Model Partitioning
An Overview of Pipeline Parallelism and its Research Progress
大模型训练 Pipeline Parallel 流水并行性能分析