Last Updated on 2025-04-16 by Clay
目前 LLM 的推理時,大部分的時間都卡在需要『逐一生成 Token』的這一環節,這顯示了當前 GPU 記憶體的瓶頸 —— 我們每次讓模型解碼出一個 Token,就必須要讀取整個模型的權重,而實際的浮點數運算量相對較小,導致 GPU 的計算能力未能充分發揮。
所以推測性解碼(Speculative Decoding)的解決方法被提了出來。它利用一個較小、推理速度較快但性能沒那麼好的 draft model 提出多個接下來的 Tokens 候選,比方說輸入『Hi』後,draft model 自動生成了『, what's your name?』,並由我們希望加速的 target model 一次性進行驗證(我們可以把 Hi, what's your name?
當作輸入,取得 target model 對於每一個 Token 的預測結果 )。最後,我們只要選擇驗證通過的部份、截斷不接受的部份,就可以在不損失性能的情況下加速推理。
而 draft model 的選擇是很重要的:它不僅需要『快』、同時也需要『準確』,這才能在驗證時有較高的接受率。
如經典的 Medusa([2401.10774] Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads),就是把 draft model 設計成一系列位於 target model 解碼層前的輕量級解碼頭(Decode Heads) ,每一個解碼頭負責解碼不同時間序的生成結果。
Hydra Heads

本篇研究提出的 Hydra Heads 與 Medusa Heads 相似,同樣是每一個 Heads 負責解碼不同時間點的生成結果;但是研究團隊更是引入了每一個時間點之前的、前一個 Head 所解碼出的 Token 資訊,進一步提昇了 Hydra Heads 的接受率。
可以想像成,本來的 Medusa 是在看不到前一個 Head 生成的確切結果、只看得到相對模糊的 hidden states 進行預測的、但是 Hydra 會明確告訴準備進行生成的 Head,前一個生成的 Token 是什麼。
原始論文寫的是 而非
,我稍微換了一下符號與 Token 的
區分。
閱讀原始碼,我們可以看到一個經典的 Hydra Head (MLP):
self.hydra_head = HydraMLP(
hydra_num_layers=self.hydra_num_layers,
hydra_num_heads=self.hydra,
grounded_heads=self.grounded_heads,
input_embed_fn=self.base_model.model.embed_tokens,
base_config=self.config,
lm_head_init_weight=base_model.lm_head.weight.data
)
self.hydra_lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
其中除了要指定層數與頭數外,也要把 input_embed_fn
放進去 —— 一般來說就是放 model 的 embedding layer。
在 HydraMLP 的前向 forward()
中,我們可以很清晰地 Hydra Heads 的運作方式:
def forward(self, base_hidden_states, input_ids=None, noise=None):
"""
Forward pass of the MLP.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output after the MLP.
"""
hydra_hidden_states = []
if self.grounded_heads:
assert input_ids is not None, "Input ids must be provided for grounded heads"
with torch.inference_mode():
input_embeds = self.input_embed_fn(input_ids)
if noise is not None:
input_embeds = input_embeds + noise
hydra_inputs = [base_hidden_states]
for i in range(self.hydra_num_heads):
# Move input embeddings back one spot for each hydra head idx
hydra_inputs.append(torch.roll(input_embeds, shifts=-(i+1), dims=1))
for i in range(self.hydra_num_heads):
head_input = torch.cat(hydra_inputs[:i + 2], dim=-1)
hydra_hidden_states.append(self.hydra_mlp[i](head_input))
else:
for i in range(self.hydra_num_heads):
hydra_hidden_states.append(self.hydra_mlp[i](base_hidden_states))
也可以很明確地看到 Token 轉換成 Embeddings 的過程:
with torch.inference_mode():
input_embeds = self.input_embed_fn(input_ids)
研究團隊進行實驗,表明這比 Medusa 快了 1.1 倍。
Hydra++
Hydra++ 是研究團隊針對 Hydra 架構提出來的訓練方案,探討了訓練目標和架構,將吞吐量(throughput)提昇到 Medusa 的 1.31 倍、原始架構的 2.70 倍。
其中有三個有效的改進(參考 Appendix A):
- 擴展(Scaling):將每個 Head 的 MLP 擴展到了 4 層 —— 實驗發現超過 5 層沒有效益。
- 蒸餾(Distillation):採用自蒸餾(self-distillation),讓 Hydra Heads 學習預測 target model 對於給定 Token 的輸出分佈、而非真實的下一個 Token。
- 前綴注意力)(Prefix Attention):為了讓 draft model 的部份更好地利用上下文(context)的資訊,在 target model 外添加一個自注意力解碼層(self-attention decoder layer)。這一層在每個解碼步驟中只被查詢一次,提供更具有資訊的隱藏狀態輸入。
樹狀解碼(Tree Decoding)
跟 Medusa 一致,Hydra 解碼也是使用事先定義好的靜態樹拓樸(static tree topology);但隨之而來的問題是,該怎麼找到最佳的拓樸結構?
研究團隊採用了一種迭代的貪婪解法,從只有一個節點的樹開始,在每一個步驟找出『在哪個現有節點底下、增加子節點能提高預期接受長度』—— 依此類推直到結束。
訓練結果
模型: Base Model 使用 Vicuna 系列(7B, 13B, 33B 參數),這些是基於 LLaMa 微調的對話模型。
訓練:
- 只訓練 Draft Heads,Base Model 的權重保持不變(frozen)
- 訓練數據集: ShareGPT (多輪對話數據)
- 硬體: 8 塊 NVIDIA A100-80GB GPU
- 框架: Hugging Face Trainer
- 優化器: AdamW (
)
- 學習率: 使用帶有預熱(warmup)的餘弦學習率調度(cosine learning rate schedule),峰值學習率為 1e-3。
- 訓練輪數 (Epochs): Hydra 和 Medusa heads 訓練 1 輪(發現已飽和),Hydra++ heads 訓練 10 輪。

References
- arXiv: Hydra: Sequentially-Dependent Draft Heads for Medusa Decoding
- https://github.com/zankner/Hydra