【元來如此】第三章——Mixtral 8x7B深入挖掘,改變游戲規則的AI模型!


正文共:2285字 8圖
預計閱讀時間:6分鐘
作者:思成
軟件生態中心·應用平臺部·模型應用組
Mixtral 8x7B是由Mistral AI在23年10月發布的一種稀疏混合專家混合模型(SMoE),在多項Benchmark測試中效果優于LLaMa2 70B和GPT3.5。基于Mixtral 8x7B得到的Mixtral 8x7B-Instruct指令跟隨模型在Human Evaluation榜單上超過了GPT-3.5 Turbo、Claude-2.1、Gemini Pro和LLaMA 2-Chat 70B。同時,這2個模型免費供學術和商業使用。
Mixtral 8x7B同樣是一個Decoder Only的模型,區別于傳統的LLaMA等模型,FNN層由8個前饋神經網絡(Expert)組成。如果我們從某一個Token的視角看,這個Token會經過其中的2個前饋神經網絡或者說Expert。也就是說雖然整個模型的參數量是46B,但是在推理過程中激活的參數只有13B。
剛剛提到對于1個Token只有2個Expert被激活,這個機制是通過路由網絡(Router)來控制。在Mixtral 8x7B中路由網絡由1個前饋神經網絡組成。整個混合專家層模型結構以及數學表示如圖1[1]所示。


圖1 混合專家層
雖然Mixtral 8x7B在推理過程中同一時間激活的參數只有13B左右,但是為了保證推理性能,還是需要將全部參數(46B)讀入顯存,以A100-80GB為例,對于46B的參數的模型,按照FP16精度來估算,參數預期占用92GB顯存。以batch為20、Context length=1024、Generate length=1024來看、KV Cache需要的顯存為20GB[2] ,也就是說理想態下(不考慮顯存碎片,Activate output),至少需要2張A100-80GB完成推理。
多卡的LLM推理場景下,我們將LLM模型結構抽象為Attention和FFN兩部分,因為Mixtral 8x7B模型不涉及Attention部分的模型結構改進,所以Attention部分的并行方案[3]遵照標準方案,如圖2所示。

圖2 Attention層模型并行方案
接下來我們來看混合專家層的并行方案設計。

數據并行
Mixtral 8x7B 2張卡的數據并行方案如圖3所示。可以看到Expert 1-8的參數在每張GPU上都被拷貝了一份,對于不同的輸入,每張卡單獨進行推理。這樣的方案不需要在Expert之間進行額外的通信開銷。但是顯而易見的問題是顯存開銷大;同時,每一張卡的輸入是一個完整的任務,需要在每張卡任務的Batch和序列緯度進行MASK,從而拆解為不同的子任務,并單獨完成模型推理,最后通過All Gather通信再次在每張卡上還原為完整的任務。
從上述過程中,可以看到算力除了特殊情況(Batch=1, Sequence length=1)之外并沒有浪費,同時在保證Batch * Sequence length % GPU_CNT== 0的情況下也不存在負載均衡問題。

圖3 Mixtral 8x7B 2卡并行推理-數據并行方案

專家并行
從數據并行的方案中可以看到,最大的問題是因為每張卡都需要保留所有的Expert帶來的顯存浪費。既然這樣,為什么不每張卡只保留若干Expert呢?專家并行就這樣被提出了。如圖4所示。
可以看到Expert 1-4被分配到了GPU 1上,Expert 5-8被分配到了GPU 2上,很好的解決了顯存開銷問題;但是同時引入了All2All通信(圖4中紅線)。和數據并行方案類似,每一張卡的輸入是一個完整的任務,需要在每張卡任務的Batch和序列緯度進行MASK,從而拆解為不同的子任務,每個GPU上的子任務將各自的任務發送到對應的Expert上完成推理,之后再發送回自身所屬的GPU,最后通過All Gather通信再次在每張卡上還原為完整的任務。

圖4 Mixtral 8x7B 2卡并行推理-專家并行方案
另外,圖4中的例子是一個理想態,Expert的激活剛好是每張卡激活2個Expert。但因為每個Token實際激活哪個Expert是不確定的,考慮更一般的情況如圖5所示。可以看到這時候GPU 1上激活了3個Expert,GPU 2上激活了1個Expert,也就是說專家并行是存在負載不均衡的。另外考慮Batch=1的推理場景,Generate階段同一時間只會有1張GPU卡被激活,造成算力的浪費。

圖5 Mixtral 8x7B 2卡并行推理-專家并行方案
為了進一步驗證Mixtral 8x7B的負載均衡問題,我們隨機挑選了若干問題,并分別統計了32層的Expert激活情況,如圖6(Heatmap)所示。

圖6 Mixtral 8x7B 不同層Expert激活情況
從負載均衡角度,全局256個專家中可以明顯看到Layer 14 的Expert 3 是最經常被路由到的,而Layer 13的Expert 4是最少被路由到的,前者比后者要勤奮5.72倍。從每層的8個Expert橫向對比也可以發現,Layer 8 中最勤奮的Expert比最懶惰的Expert多工作4.3倍。


模型并行
為了解決數據并行和專家并行帶來的若干問題,我們討論模型并行的方案,如圖7所示。將Expert按照Embedding維度縱向切分。每張卡保留所有Expert的一個切片。

圖7 Mixtral 8x7B 2卡并行推理-模型并行方案
GPU 1上保留了Expert 1-8的部分參數(左側實線),GPU 2上保留了Expert 1-8的另外一部分參數(右側實線)。同時在輸入側不需要按照Rank進行MASK操作,每張卡處理全部的數據,并根據各自的參數切片得到輸出的切片,最后通過All Reduce通信將輸出融合。
從上述的過程中,可以看到,模型并行的方案不存在負載均衡問題,同時受到的約束相對來說最少(Dimension % Rank == 0),同時不會引入額外的通信開銷。
綜上,我們發現基于稀疏混合專家混合模型的LLM推理主要限制因素包括顯存、通信、計算3方面。而數據并行的方式在顯存方面存在巨大的劣勢,模型并行和專家并行在顯存方面的顯存占用一致。所以問題進一步簡化為通信和計算兩方面。


太初元碁Tecorigin基于上述分析,深度優化了Mixtral 8x7B模型推理,在Batch size=1,Context length=1024,Generate length=1024下,端到端推理速度百分位(越大越好)對比基于GPU A800 * 2硬件的不同開源LLM推理框架效果如圖8所示。

圖8
Mixtral 8x7B 2卡并行推理TecoInference端到端速度對比
至此,本文簡要介紹了Tecorigin在Mixtral 8x7B模型推理上的探索。未來,期待更多的大模型技術跟大家一起分享、交流、討論。

參考文獻
[1] [2401.04088] Mixtral of Experts (arxiv.org)
[2] 【元來如此】第二章——打破序列長度限制,讓無限Token成為可能!(qq.com)
[3] [2104.04473] Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (arxiv.org)