Last Updated on 2025-03-25 by Clay
最近依然還是在看加速推理的東西,奈何手邊一直在忙工作的事情沒來得及發出來呢;今天要介紹的加速推理架構是稱為 Medusa 的經典多頭解碼模型。
梅杜莎(Medusa)是希臘神話中的經典角色,也有一些翻譯是稱為蛇髮美人、其每一根髮絲都是一隻小蛇 —— 而本篇論文所設計的架構也是呼應這個形象,有著多個解碼頭(Decoding Heads)。
現在的大型語言模型多是採用自迴歸解碼(auto-regressive decoding),每個當前的解碼結果都依賴於前一步驟的輸出,這也導致了計算瓶頸。
現有的方法中,推測性解碼(Speculative Decoding)能夠有效加速推理,但需要維護額外的草稿模型(draft model),管理與佈署都增加了複雜度。(可以參考我之前寫過的筆記:[論文閱讀] Fast Inference from Transformers via Speculative Decoding 和 [論文閱讀] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding)
而 Medusa 架構在採用了不同解碼頭預測不同時間序的 Tokens 外,也提出了使用樹狀注意力機制(tree-based attention)來同時驗證多個候選結果。基本上,Medusa 與多數 Speculative Decoding 的方法不同,無法保證加速後的輸出分佈與原始 LLM 一致,所以需要額外做實驗印證模型的性能沒有損失。
本研究提出了兩種訓練方案:
- Medusa-1: 在固定的基礎模型上微調 Medusa Heads,不影響生成品質,提供 2.2 倍的加速。
- Medusa-2: 與基礎模型一起微調,能提高 Medusa 預測準確性,提供 2.3 倍到 2.8 倍的加速,速度更快、但同時需要更多訓練。
Medusa 架構
Medusa 架構中最重要的有兩個:一個是多個解碼頭、一個則是樹狀注意力機制。
多重解碼頭
解碼頭的架構如下:

在實作中,如果怕把 decoding layer 重新訓練的話會有太多參數(head_num * last_hidden_size * vocab_size),更簡單一點的實作是共享最後一層 decoding layer,把 Medusa Heads 宣告為多層線性層,如此一來參數量只會有 head_num * last_hidden_size * last_hidden_size * layer_num,而通常 vocab_size 遠大於 last_hidden_size。
樹狀注意力機制
樹狀注意力機制是用於驗證多個不同的解碼結果的。下圖中打勾處,代表著模型能夠看到的部份;其餘沒有打勾處,則是被 Attention Mask 掉的部份。
下圖總共有 6 個候選路徑:
- Head 1: ["It", "I"]
- Head 2: ["is", "'", "the"]
2 * 3 = 6,我們可以組合出:
- It is
- It'
- It the
- I is
- I'
- I the

之後,Medusa 並不是採用 Speculative Decoding 的 Reject Sampling,而是採用經典接受方案(Typical Acceptance Scheme)。
核心概念是不要求預測的 Tokens 與原始模型的分佈完全相同,而是選擇最典型(Typical)的序列來接受;另外,也並不是單純設定一個硬閾值(hard threshold)來判斷是否接受 Token,而是另外引入了熵(entropy)來動態控制閾值,在高熵時允許更寬容的 Token 多樣性、在低熵時更嚴格地選擇高機率 Token。
這樣講很模糊,直接來看公式。
假設我們定義序列 ,我們透過以下的條件決定是否要接受候選 Token:
是原始模型對當前 Token 的機率分佈
是 Shannon 熵函數,用來測量模型對當前狀態的不確定性,實際上計算為
是硬閾值(hard threshold),可以設定為 0.01、0.05… 等等
是與熵相關的可變係數,可取 0.5, 0.7, … 等等
當然,我們可能會接受很多條候選序列,此時為了考慮加速效果,選擇最長的序列。
訓練方案 1: Medusa-1 Frozen Backbone
最經典的一種是只針對額外多加的幾個 Medusa Heads 進行訓練,而把原始模型架構凍結。
訓練的資料可以先拿任何對話資料集,但比較好的方法應該還是讓原始模型針對問題集生出自己的回應讓 Medusa Heads 學習 —— 這樣才能確保 Heads 學習到的是原始模型真正會輸出的分佈,提高接受律。
而 loss function 可以簡單地使用 cross entropy,並根據不同位置的 Heads 調降權重;這個概念是讓預測時間序越往後的 Heads 可以不用強迫自己學習。
其中 是一個 k 次方的衰減係數,論文中定為 0.8,讓越往後的 Heads 的 loss 越低。
訓練方案 2: Medusa-2 Joint Trainin
Combine Loss
為了保持原始模型 Next Token Prediction 的性能,我們需要把 Cross Entropy 加入 Medusa loss 中。
整體的 loss function 看起來會是:
不同的學習率
這一小節沒有特別說明該如何配置,但提出了本來的原始模型是訓練好的、而 Medusa Heads 是全新初始化的,直覺上可以替兩邊配置不同的 Optimizer 以及不同的 Learning Rate。
Heads Warmup
這邊提出可以先只訓練 Medusa Heads,保持在 Medusa-1 的方案作為『預熱』;然後在進行 Medusa-2 的訓練方案 —— 讓原始模型跟著 Medusa Heads 一起微調。
自蒸餾機制說明(Self-Distillation)
在 Medusa 架構中,訓練解碼頭(Medusa Decoding Heads)需要一個與原始模型輸出分佈相符合的資料集。然而在實際情況中,有些模型的訓練資料可能並未公開,或者模型已經透過透訓練方式,比如 RLHF 進行微調 —— 這種情況造成原始訓練資料的分佈與模型的輸出分佈不同。
這時候,可以採用『自蒸餾』(Self-Distillation)的方式產生適合 Medusa 訓練的資料。
透過模型來產生訓練資料集的方式非常直觀,我們先從公開資料集中選取所謂的種子資料集(Seed Dataset),並把提示(Prompts)當作輸入送入原始模型,並取得模型生成的回應(Response),以此組成新的資料集 —— 我們接下來就讓 Medusa Heads 在這份資料集上面訓練即可。
不過在 Medusa-1 的訓練方案上(凍結原始模型,只訓練 Medusa Heads)可以這種做,但是在 Medusa-2 的訓練方案中(同時微調原始模型和 Medusa Heads),直接使用自蒸餾資料集的硬標籤(hard label)合成資料訓練原始模型會導致性能下降。
這裡說明一下我的個人理解:我本來想說,拿自己生成的資料訓練自己,不就是保持一致的性能嗎?怎地說效果會降低呢?後來轉念一想,本來模型的輸出是帶有更多的資訊,在不同溫度、採樣方式下會生成不同的句子,現在則是使用 hard label 的方式訓練模型看到特定的輸入時,就一定聚焦於自己的某種偏好輸出 —— 久而久之恐怕模型的行為確實會變死板,也造成所謂的『效果下降』。
為了解決這個問題,在 Medusa-2 時的 loss function 調整為直接拿原始模型的機率分佈而非確定的 token 作為蒸餾目標,這個方法就近乎傳統的知識蒸餾(Knowledge Distillation, KD):
實驗結果
最後放一下 Medusa 的研究團隊實驗結果。
在 Vicuna 7B 和 13B 兩個量級上,分別取得了,2.18x~2.33x(Medusa-1)和 2.83x(Medusa-2)的提昇。

雖然效果不錯,但我內心還一直對不完全符合原始模型的機率分佈輸出有些牴觸,只能期待一下之後看的 Eagle 論文,是怎麼說明應用了樹狀注意力機制並保持與原始模型同樣的分佈輸出的。