Skip to content

[論文閱讀] Fast Inference from Transformers via Speculative Decoding

Last Updated on 2024-11-03 by Clay

Abstract - 摘要

在自迴歸模型(Auto-regressive Model)解碼時,如果需要解碼 K 個詞元(Tokens),則需要跑 K 次流程,而這正是當前大型語言模型的推理時間瓶頸所在。

本篇閱讀的論文研究提出了推測性解碼(Speculative Decoding) 一種利用平行化計算、無損解碼準確性的加速推理演算法。


Introduction

一個核心的觀察點是:在模型的解碼過程中,並不總是困難的任務。比方說,在生成專業論述時,其知識表達的精準度非常重要且困難;然而,日常跟別人打招呼時卻可以輕鬆到甚至可以無視文法。

以上舉例的兩種情境難度自然天差地別。

而研究者希望可以找到一種不用重新訓練模型、不改變既有模型輸出機率分佈的方法,並且這種方法啟發自 Speculative Execution投機化執行)。

Speculative execution
Speculative execution is an optimization technique where a computer system performs some task that may not be needed. Work is done before it is known whether it is actually needed, so as to prevent a delay that would have to be incurred by doing the work after it is known that it is needed. If it turns out the work was not needed after all, most changes made by the work are reverted and the results are ignored.
引用自維基百科: https://en.wikipedia.org/wiki/Speculative_execution

在 Transformer 的解碼任務上,研究者們設想使用一些近似模型(approximation models) 去進為速度較慢的目標模型(target models) 進行前期的投機採樣(speculative sampling),最後再藉由讓目標模型推理下一個 Token,讓過往的投機採樣全部都一併經過驗證。

在上方的圖例中,綠色的 Token 是由 approximation model 產生並被 target model 接受的 Token;而紅色的 Token 則是被 target model 拒絕的 Token,並且會由 target model 自行生成正確的 Token - 也就是藍色的 Token。

而所謂的 rejected suggestions(紅色 Token),即是與 target model 在該 Token 位置的機率分佈中不會選取的答案,所以被拒絕了;相對的,accepted token 則是跟 target model 在該 Token 位置的機率分佈最高相似,故而有機為被接受(下一章會仔細說明驗證的機制與數學,說明為什麼 Speculatve Decoding 可以稱為『無損加速』)。

這樣一來,就能保證利用 Speculative Decoding 的方式,不會讓 target model 的輸出產生變化,而是在 approximation model 推理速度較快的情況下保證加速。

根據實驗結果,在輕量模型與大模型的搭配下,提昇速度約為 2x - 3x


Speculative Decoding

基本的方法如上一節所述,由於 target model 會同時評估所有的投機採樣,所以在最差的情況下,target model 從第一個 Token 就拒絕了,並且改用 target model 的解碼結果(在評估時會同時得到,不用重新計算),但在 approximation model 解碼速度夠快的情況下,其實時間損失並不嚴重。

在最順利的情況下,target model 會接收前面所生成的 t 個 Tokens,並且會因為評估時本身就是在推理第 t+1 個 Token,所以總共會利用 1 單位的推理時間,順利得到 t+1 個 Tokens。

而採樣時,其實還有著許多不同的參數可以調整:比方說 argmax、top-k、nucleus(top-p)和溫度設定會在 logits level 上影響,最終依然可以回歸到標準的機率分佈採樣模式,所以無論使用哪種參數設定,都可以使用 Speculative Decoding 方法的加速。

以下我們實際來看 target model 對於 approximation model 的推測解碼,到底是用什麼機制來判度接受與否。

在這裡,我並不會使用原始論文中的符號,而會嘗試使用自己的符號替代,力求更簡潔清楚地說明。

假設我們有一個草稿模型(draft model)的對於某個特定位置解碼 Token 的條件機率 P_{draft}(y_{t}|y_{<t}) 以及目標模型(target model)同樣位置解碼 Token 的條件機率 P_{target}(y_{t}|y_{<t}),以下,我將其簡寫為 P_{draft}(x_{t})P_{target}(x_{t})

當我們利用 draft model 生成一個 Token 時,我們會使用 target model 進行驗證,再決定是否接受。實作上,我們會計算出一個 u 代表來自均勻分佈 U(0,1) 的隨機變量。

而在以下的情況,我們會決定拒絕使用 draft model 的推測解碼:

\frac{P_{target}(x_{t})}{P_{draft}(x_{t})} < u

這會有兩種情況:

  1. P_{target}(x_{t}) > P_{draft}(x_{t}): 在這種情況比值一定大於 1,意味著這個由 draft model 推測的 Token 在 target model 上只會機率更高,應該被接受
  2. P_{target}(x_{t}) \le P_{draft}(x_{t}): 在這種情況,draft model 的機率分佈與 target model 並不相似,需要有一定機率將其拒絕

而這個接受的機率為 \frac{P_{target}(x_{t})}{P_{draft}(x_{t})},拒絕的機率則為 1-\frac{P_{target}(x_{t})}{P_{draft}(x_{t})}

之所以採用這樣的計算方式,是因為我們希望此 Token 的採樣分佈,與我們直接使用 target model 是一致的、無任何損失的。

以下是論文中的證明:

對任何的 x_{t},它的機率可以寫成:

P(x=x_{t})=P(accepted, x=x_{t})+P(rejected, x=x_{t})

那在由 draft model 解碼後並被接受的可能性,則是可以展開寫成 draft model 解碼 token 的機率乘以驗證機制的接受機率:

P(accepted, x=x_{t})=P_{draft}(x_{t})min(1,\frac{P_{target}(x_{t})}{P_{draft}(x_{t})})=min(P_{draft}(x_{t}),P_{target}(x_{t}))

而拒絕的機率表示,我們需要額外引入一個 \beta 符號表示接受的機率(這部份請看原始論文的 Appendix A.1),也即是說 1-\beta 則為拒絕的機率;而 $latex P'{target}(x{t})$ 則為我們調整後由 target model 解碼的機率,也即是上面我們所列的 \frac{P_{target}(x_{t})}{P_{draft}(x_{t})}

所以拒絕的機率表示,就可以寫成:

$latex P(rejected, x=x_{t})=(1-\beta)P'{target}(x{t})=P_{target}(x_{t})-min(P_{draft}(x_{t}),P_{target}(x_{t}))$

接著將 accepted 和 rejected 的機率表示相加,就會得到:

P(accepted, x=x_{t}) + P(rejected, x=x_{t})=P_{target}(x_{t})-min(P_{draft}(x_{t}),P_{target}(x_{t}))+min(P_{draft}(x_{t}),P_{target}(x_{t}))=P_target(x_{t})

所以在我們把調整機率列為 \frac{P_{target}(x_{t})}{P_{draft}(x_{t})} 後,可以保證等價於原本的 target model 對於此 Token 解碼的機率分佈。

通俗一點地直覺來想,我們本來 target model 解碼 x_{t} 的機率假設是 0.3,而現在 draft model 解碼 x_{t} 的機率是 0.8,那麼如果我們現在要拒絕此 token 的話,因為前面草稿模型已經抽樣過一次了,所以我們直接拿 1-0.3=0.7 的機率去拒絕它,其實就會不符合我們本來 target model 的機率分佈(第一次在草稿模型的抽樣已經影響到我們的抽樣機率了),現在我們需要校正機率值,公平地讓原始 target model 的機率分佈 0.3 為此 Token 的整體解碼機率,所以我們反過來調整一下接受解碼的機率 P',也就是 0.8 * P' = 0.3,那麼 P'=\frac{0.3}{0.8}=0.375 就是一件再自然不過的事情了。

以下是論文中列出的演算法,根據此演算法,應能正確實作 Speculative Decoding。


總結

論文的後續大部份是提出其實驗結果與他們改進實際時間(wall-time improvement)的計算方式,仍然非常值得一讀,但受限於篇幅,這裡就不展開來細說。

但綜上所述,此篇論文對於大型語言模型的加速推理無疑是相當具有啟發性的,因為 Speculative Decoding 是真正具備可用性與實際價值的方法,後續也有著許多工作是基於此篇論文繼續往下深挖。

我會再另外寫一篇實際操作 Speculative Decoding 並觀察其改進推理效果的文章分享,希望能趕快接觸到更多加速推理技巧的全新方法!因為我相信,一旦突破了計算時間的枷鎖,AI 才能真正走入人類社會並造福大眾。


References


Read More

Leave a Reply