Skip to content

[論文解讀] Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting

前言

這是華為諾亞方舟實驗室所提出加速框架,本質上是把原先投機解碼(speculative decoding)中所使用的小模型由大模型的淺層網路取代,並再由額外訓練的適配器(adapter)加上模型本身的解碼頭去生成推測的 token,再由大模型去進行驗證 —— 後續的操作與原先的 speculative decoding 其實沒有太大差別。


背景介紹

為了介紹 Kangaroo 的加速方法,首先得先介紹投機解碼speculative decoding)這種技術。

投機解碼是個被驗證用於加速模型推理非常有用的技巧,其基本概念是通過量級較小並且推理速度較快的草稿模型(drafter)來快速生成一些推測性的解碼結果,這些解碼的結果再交由大模型進行檢驗。

比方說草稿模型可能會一口氣生成 "Today"、"is"、"a"、"nice"、"good" 這樣一個系列的解碼結果,接著再由大模型一口氣預測以下結果:

  • "" -> "Today" (O)
  • "Today" -> "is" (O)
  • "Today is" -> "a" (O)
  • "Today is a" -> "nice" (O)
  • "Today is a nice" -> "day" (X):原本草稿模型預測出 "good"

這 5 個步驟都是一口氣被大模型平行解碼完成的,不過中間如果某個由草稿模型預測的解碼結果被回絕了,則會更改為大模型預測出的解碼結果,再交由草稿模型繼續往下預測。
這樣一來,本來單獨由大模型依序完成 5 次解碼的時間,就變更為一次推理產生結果了。

不過由於草稿模型與大模型平行解碼的緣故,其 GPU VRAM 的開銷會比單獨大模型解碼來得更高,就某種程度上算是一種利用空間換取時間的策略。

而說明完了投機解碼的概念之後,接著就可以進入正題介紹 Kangaroo 的作法了。


Kangaroo 作法

一致 Token 接受率(Consistent Token Acceptance Rate, CTAR)

Kangaroo 提出的架構很有趣,但在看它的架構之前,首先來談談研究者所提出的新評估指標。

一般在投機解碼時通常使用兩個指標:『實際時間加速率』(wall-time speedup ratio)和『壓縮率』(compression rate)。但 Kagaroo 的研究團隊指出,這些指標並不能反映出在不同上下文情境中草稿模型 token 接受率。

Compression\ Rate\ (CR) = \frac{Accepted\ Draft\ Tokens}{Total\ Tokens}


研究團隊引入一個新的評估指標『一致 Token 接受率』(Consistent Token Acceptance Rate, CTAR),用於評估在給定前綴(prompt)和後續窗口尺寸的情況下,草稿模型推測的 token 全部被目標模型接受的概率。

Consistent\ Token\ Acceptance\ Rate\ (CTAR) = \frac{Accepted\ Draft\ Windows}{Total\ Windows}


直覺上能很清楚地想像,隨著窗口尺寸增大,CTAR的分數想必會開始遞減。


提早退出機制

在 Kangaroo 中所使用的草稿模型,並非如原本 Speculative Decoding 方案那樣使用另外一個完整的小模型作為草稿模型,而是把目標模型(大模型)的淺層網路(Shallow Sub-Network)作為草稿模型的一部分,再額外添加適配器網路(Adapter-Network)以及拿模型最後一層的解碼層(LM Head)組成草稿模型,所以可以大幅地減少草稿模型的參數量。

實際上也確實只有加入適配器網路的參數量而已。

接著就是 Kangaroo 架構所提出的提早退出機制了。提早退出機制的意思是,當草稿模型對於當前 token 的預測信心值低於某個閾值時,便會提早停止生成,而把當前預測的所有推測 tokens 交由目標模型去評估,也變相地把草稿模型沒把握的困難解碼交給目標模型去處理。

可以理解成如果草稿模型都對自己的生成沒信心了,那不如早點交給目標模型去解碼,反正就算草稿模型自己花費計算時間去解碼,目標模型驗證時的拒絕機率也是很高。

詳細的資料流可以從上方的架構圖中確認。值得一提的是,若要訓練 Kangaroo 架構的模型,其訓練參數僅有適配器網路(通常只包含兩個正規化層和一個多頭注意力機制)而已,可說是低成本的微調任務。


實驗結果

在實驗中,Kangaroo 在 Spec-Bench 上顯示了顯著的加速效果,達到了 1.68 倍的加速,並且所需要的額外參數量比其他方法少了 88.7%。


結尾

由於 Kangaroo 的原始碼有開源在 GitHub 上(請參考下方參考連結),我在很早期的階段就去拜讀過了;不過一直到今天(2024-06-03)依然沒有釋放出訓練的腳本。

當然,由於 Kangaroo 的架構已經寫好了,直接拿這個架構來訓練也不是不行,說實話我是挺想自己測試一遍的,畢竟我今年自己設定的目標就是加速推理方面的研究。

不過在那之前,我會想先把 Medusa 的實作先完成吧!那邊開了個頭,後來又因為工作忙碌暫時擱置了。

不過加速推理技術很仰賴現行的框架,因為只是改進架構本身是沒法應用在產品上的,終究還是得使用上優化整合深入的框架才行。

所以,閱讀各種加速推理框架的原始碼也是課題之一吧!


References


Read More

Leave a Reply