
[Paper Reading] Fast Inference from Transformers via Speculative Decoding

Last Updated on 2024-11-06 by Clay

Abstract

In auto-regressive decoding, generating K tokens requires K sequential forward passes of the model, because each token depends on all the tokens before it. This sequential dependency is currently the main bottleneck in the inference time of large language models.

The paper reviewed here presents Speculative Decoding, an algorithm that accelerates inference by leveraging parallel computation, without changing the model's output distribution.
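
To make the bottleneck concrete, here is a minimal sketch of plain autoregressive sampling. It assumes a hypothetical `toy_model` that stands in for a full LLM forward pass and returns a probability distribution over a 4-token vocabulary; `autoregressive_decode` is likewise a hypothetical name. The only point is the call count: K new tokens cost K sequential model calls.

```python
import random

VOCAB_SIZE = 4  # hypothetical toy vocabulary


def toy_model(prefix):
    """Hypothetical stand-in for an LLM forward pass: prefix -> next-token probs."""
    last = prefix[-1] if prefix else 0
    weights = [1.0 + ((last + i) % VOCAB_SIZE) for i in range(VOCAB_SIZE)]
    total = sum(weights)
    return [w / total for w in weights]


def autoregressive_decode(model, prefix, k, rng):
    """Standard decoding: K new tokens need K sequential model calls."""
    tokens = list(prefix)
    calls = 0
    for _ in range(k):
        probs = model(tokens)  # one forward pass per generated token
        calls += 1
        tokens.append(rng.choices(range(VOCAB_SIZE), weights=probs)[0])
    return tokens, calls


tokens, calls = autoregressive_decode(toy_model, [0], k=5, rng=random.Random(0))
print(len(tokens) - 1, calls)  # 5 5
```

Each iteration has to wait for the previous token, so these calls cannot run in parallel; this is the loop that Speculative Decoding targets.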


Introduction

A core observation is that not every step of the decoding process is equally difficult. For instance, generating academic text demands precise knowledge and is genuinely challenging, whereas a casual greeting is so simple that even its grammar hardly matters.

These two scenarios naturally differ in difficulty, which suggests that the easier tokens could be predicted well by a much smaller model.

The researchers aim to find a method that requires neither retraining the model nor altering its output probability distribution, and their approach is inspired by speculative execution.

Speculative execution
Speculative execution is an optimization technique where a computer system performs some task that may not be needed. Work is done before it is known whether it is actually needed, so as to prevent a delay that would have to be incurred by doing the work after it is known that it is needed. If it turns out the work was not needed after all, most changes made by the work are reverted and the results are ignored.
Cited from Wikipedia: https://en.wikipedia.org/wiki/Speculative_execution

For Transformer decoding, the researchers propose using a faster approximation model to speculatively sample several tokens ahead of the slower target model. The target model then evaluates all of these speculative tokens in parallel, in a single run, keeping the drafts up to the first rejection and producing one token of its own.

In the paper's illustration, green tokens are suggestions from the approximation model that the target model accepted, red tokens are suggestions the target model rejected, and blue tokens are the target model's corrections that replace the rejected ones.

The “rejected suggestions” (red tokens) are tokens that fail the target model's verification at that position. Verification is probabilistic rather than a simple check against the target model's most probable token: each draft token is accepted with a probability determined by comparing the two models' probabilities for it. This verification process (explained mathematically in the next section) is why Speculative Decoding is called a “lossless acceleration” technique: the output follows exactly the target model's distribution.

Thus, Speculative Decoding can accelerate inference without altering the target model's output distribution, provided that the approximation model is much faster than the target model and its guesses are accepted often enough.

In the paper's experiments on T5-XXL, pairing a lightweight approximation model with the larger target model sped up inference by roughly 2x to 3x, with identical outputs.


Speculative Decoding

The basic method, as described in the previous section, relies on the target model evaluating all speculative tokens simultaneously. In the worst case, the target model rejects the very first token proposed by the approximation model and supplies its own token for that position instead, which is available from the same run without recomputation. That iteration then yields a single token, just like standard decoding, so the only overhead is the time spent running the approximation model; if that model is fast enough, the loss is minimal.

In the best case, the target model accepts all t draft tokens. Because the same parallel evaluation also yields the target model's distribution for position t+1, it can sample one more token, producing t+1 tokens for the cost of a single target-model run.
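
The gap between these two cases can be quantified. The sketch below assumes, as the paper does for its analysis, that each draft token is accepted independently with the same probability alpha (the paper's expected acceptance rate, not otherwise defined in this post), that t draft tokens are proposed per iteration (called gamma here, as in the paper), and that one draft-model run costs c times a target-model run. Expected tokens per target run is then the geometric sum 1 + alpha + ... + alpha^gamma, and dividing by the iteration cost gamma*c + 1 gives the expected speedup. The function names are hypothetical.

```python
def expected_tokens_per_target_call(alpha, gamma):
    """Expected tokens per target-model run when each of the gamma drafts is
    accepted independently with probability alpha: sum of alpha**i, i = 0..gamma.
    The i = 0 term is the token the target pass always contributes."""
    if alpha == 1.0:
        return gamma + 1  # every draft accepted: the best case
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def walltime_improvement(alpha, gamma, c):
    """Expected speedup over plain decoding, with c = cost of one draft-model run
    relative to one target-model run. One iteration costs gamma * c + 1 units."""
    return expected_tokens_per_target_call(alpha, gamma) / (gamma * c + 1)


print(expected_tokens_per_target_call(0.0, 4))   # worst case: 1.0
print(expected_tokens_per_target_call(1.0, 4))   # best case: 5
print(round(walltime_improvement(0.8, 4, 0.05), 2))  # 2.8
```

The numbers 0.8 and 0.05 are made up for illustration; they show how a high acceptance rate and a cheap draft model together drive the speedup.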

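During sampling, decoding strategies such as argmax (greedy decoding), top-k, nucleus (top-p), and temperature scaling reshape the logits or change how tokens are drawn. The paper points out that all of them are equivalent to standard sampling from an adjusted probability distribution, so Speculative Decoding can simply be run on the adjusted distributions and accelerates decoding under any of these settings without changing the resulting output distribution.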
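
A minimal sketch of such an adjustment, assuming raw logits given as lists of floats: the hypothetical helper `adjusted_probs` applies temperature and top-k (nucleus is omitted for brevity) and treats temperature 0 as argmax. Speculative Decoding would then operate on the resulting P_target and P_draft. Only the target side must be adjusted for correctness; applying the same settings to the draft model is an assumption here, chosen because it presumably keeps the two distributions close and the acceptance rate high.

```python
import math


def adjusted_probs(logits, temperature=1.0, top_k=None):
    """Turn raw logits into the distribution actually sampled from.
    temperature == 0 is treated as argmax (greedy) decoding."""
    n = len(logits)
    if temperature == 0:
        best = max(range(n), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(n)]
    scaled = [l / temperature for l in logits]
    if top_k is not None:
        # Keep logits >= the k-th largest (ties may keep a few more); mask the rest.
        cutoff = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [s if s >= cutoff else -math.inf for s in scaled]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for the same position from both models.
draft_logits = [2.0, 1.0, 0.5, -1.0]
target_logits = [1.5, 1.8, 0.2, -0.5]
p_draft = adjusted_probs(draft_logits, temperature=0.7, top_k=2)
p_target = adjusted_probs(target_logits, temperature=0.7, top_k=2)
print(p_draft, p_target)
```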

Next, we examine how the target model evaluates the approximation model's speculative tokens and decides whether to accept or reject each one.

Here, I'll avoid the paper's original symbols and use simpler custom notation for clarity.

Assume we have a draft model (the approximation model) with conditional probability P_{draft}(y_{t}|y_{<t}) for the token at position t, and similarly a target model with conditional probability P_{target}(y_{t}|y_{<t}). For a candidate token x_{t} at that position, we abbreviate these as P_{draft}(x_{t}) and P_{target}(x_{t}).

When the draft model generates a token, the target model verifies it to decide whether to accept it. In practice, we draw a random number u from the uniform distribution U(0,1).

The token from the draft model is rejected in cases where:

\frac{P_{target}(x_{t})}{P_{draft}(x_{t})} < u

This leads to two scenarios:

  1. P_{target}(x_{t}) \ge P_{draft}(x_{t}): the ratio is at least 1, so it can never fall below u. The token generated by the draft model is at least as likely under the target model, so it is always accepted.
  2. P_{target}(x_{t}) < P_{draft}(x_{t}): the draft model over-weights this token relative to the target model, so it must be rejected with probability 1-\frac{P_{target}(x_{t})}{P_{draft}(x_{t})}.

Combining both cases, the acceptance probability is min(1,\frac{P_{target}(x_{t})}{P_{draft}(x_{t})}), and the rejection probability is 1-min(1,\frac{P_{target}(x_{t})}{P_{draft}(x_{t})}).

Together with the resampling step described below, this rule ensures that the final sampling distribution for this token matches the target model's exactly, which is what makes the acceleration lossless.
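
The acceptance rule itself is a one-liner. A minimal sketch, assuming the two probabilities of the drafted token are already known; `accept_draft_token` is a hypothetical name. Accepting when u <= ratio, with u drawn from [0, 1), gives an acceptance probability of min(1, ratio), because a ratio of 1 or more always passes.

```python
def accept_draft_token(p_target_x, p_draft_x, u):
    """Verify one draft token x given u ~ U(0, 1).

    Rejects exactly when P_target(x) / P_draft(x) < u, i.e. accepts with
    probability min(1, ratio). p_draft_x > 0 because the draft model sampled x.
    """
    return u <= p_target_x / p_draft_x


print(accept_draft_token(0.5, 0.4, u=0.99))  # True: ratio 1.25 >= 1, always accepted
print(accept_draft_token(0.3, 0.8, u=0.5))   # False: ratio 0.375 < 0.5
print(accept_draft_token(0.3, 0.8, u=0.2))   # True: 0.2 <= 0.375
```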

The following is the proof from the paper:

For any token x_{t}, the probability that the final output is x_{t} can be split into two cases:

P(x=x_{t})=P(accepted, x=x_{t})+P(rejected, x=x_{t})

The probability that the draft model proposes x_{t} and it is accepted equals the probability of the draft model sampling x_{t}, multiplied by the verification step's acceptance probability:

P(accepted, x=x_{t})=P_{draft}(x_{t})min(1,\frac{P_{target}(x_{t})}{P_{draft}(x_{t})})=min(P_{draft}(x_{t}),P_{target}(x_{t}))

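For the rejection case, let \beta=\sum_{x}min(P_{draft}(x),P_{target}(x)) denote the overall probability that a draft token is accepted (see Appendix A.1 of the original paper), so 1-\beta is the overall rejection probability. After a rejection, the replacement token is sampled from an adjusted target distribution P'_{target}(x)=\frac{max(0,P_{target}(x)-P_{draft}(x))}{\sum_{x'}max(0,P_{target}(x')-P_{draft}(x'))}, which keeps only the probability mass where the target model exceeds the draft model. Its normalizing constant equals 1-\beta, since \sum_{x}(P_{target}(x)-min(P_{draft}(x),P_{target}(x)))=1-\beta.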
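
The adjusted distribution takes only a few lines. A minimal sketch, assuming both distributions are plain Python lists over the same vocabulary; `adjusted_distribution` is a hypothetical helper and the numbers are made up for illustration.

```python
def adjusted_distribution(p_target, p_draft):
    """Resampling distribution after a rejection:
    P'_target(x) = max(0, P_target(x) - P_draft(x)) / sum over x' of the same.
    The normalizer equals 1 - beta, the overall rejection probability."""
    residual = [max(0.0, pt - pd) for pt, pd in zip(p_target, p_draft)]
    total = sum(residual)
    if total == 0.0:
        # Identical distributions: a rejection can never happen, so this is unused.
        return list(p_target)
    return [r / total for r in residual]


p_target = [0.3, 0.5, 0.2]
p_draft = [0.8, 0.1, 0.1]
beta = sum(min(pt, pd) for pt, pd in zip(p_target, p_draft))
print(beta)                                      # approximately 0.5
print(adjusted_distribution(p_target, p_draft))  # approximately [0.0, 0.8, 0.2]
```

Token 0, which the draft model over-weights, gets zero probability in P'_target: its target probability is already fully covered by the acceptance path.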

The probability that the draft token is rejected and the replacement token is x_{t} is therefore:

P(rejected, x=x_{t})=(1-\beta)P'_{target}(x_{t})=max(0,P_{target}(x_{t})-P_{draft}(x_{t}))=P_{target}(x_{t})-min(P_{draft}(x_{t}),P_{target}(x_{t}))

Adding the accepted and rejected probabilities yields:

P(accepted, x=x_{t}) + P(rejected, x=x_{t})=min(P_{draft}(x_{t}),P_{target}(x_{t}))+P_{target}(x_{t})-min(P_{draft}(x_{t}),P_{target}(x_{t}))=P_{target}(x_{t})

Thus, accepting the draft token with probability min(1,\frac{P_{target}(x_{t})}{P_{draft}(x_{t})}) and resampling from the adjusted distribution after a rejection ensures that the output matches the target model's original decoding distribution for that token.
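
With a small toy vocabulary, the proof can also be checked exactly rather than by sampling. This sketch uses a hypothetical helper `speculative_output_distribution` and made-up distributions; it adds the accepted and rejected terms for every token and should reproduce P_target up to floating-point error.

```python
def speculative_output_distribution(p_target, p_draft):
    """Exact probability that one verification step outputs each token x:
    P(accepted, x) + P(rejected, x)
      = min(P_draft(x), P_target(x)) + (1 - beta) * P'_target(x)."""
    accepted = [min(pd, pt) for pt, pd in zip(p_target, p_draft)]
    beta = sum(accepted)  # overall acceptance probability
    residual = [max(0.0, pt - pd) for pt, pd in zip(p_target, p_draft)]
    norm = sum(residual)  # equals 1 - beta
    rejected = [(1 - beta) * r / norm if norm > 0 else 0.0 for r in residual]
    return [a + r for a, r in zip(accepted, rejected)]


p_target = [0.3, 0.5, 0.2]
p_draft = [0.8, 0.1, 0.1]
out = speculative_output_distribution(p_target, p_draft)
print([round(v, 10) for v in out])  # [0.3, 0.5, 0.2], i.e. equal to p_target
```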

In simpler terms, suppose the target model's decoding probability for x_{t} is 0.3 and the draft model's is 0.8. The draft model proposes x_{t} too often, so we accept it only with probability a=\frac{0.3}{0.8}=0.375 (and reject it with probability 0.625). The probability that x_{t} is both proposed and accepted is then 0.8 \times 0.375 = 0.3, exactly the target model's probability. Because P_{target}(x_{t}) < P_{draft}(x_{t}), the resampling step gives x_{t} no additional probability (max(0, 0.3-0.8)=0), so the target model's probability for this token is preserved.
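
The same example can also be checked empirically. This Monte Carlo sketch assumes a hypothetical two-token vocabulary in which token 0 plays the role of x_{t} (0.8 under the draft model, 0.3 under the target model). It drafts a token, verifies it, resamples on rejection, and counts how often token 0 comes out; `one_verification_step` is a hypothetical name.

```python
import random


def one_verification_step(p_target, p_draft, rng):
    """Draft one token, verify it, and resample on rejection (single position)."""
    vocab = range(len(p_draft))
    x = rng.choices(vocab, weights=p_draft)[0]       # draft model proposes x
    if rng.random() <= p_target[x] / p_draft[x]:     # accept w.p. min(1, ratio)
        return x
    # Rejected: resample from norm(max(0, P_target - P_draft)).
    residual = [max(0.0, pt - pd) for pt, pd in zip(p_target, p_draft)]
    return rng.choices(vocab, weights=residual)[0]


p_target = [0.3, 0.7]  # token 0 has probability 0.3 under the target model
p_draft = [0.8, 0.2]   # ... and 0.8 under the draft model
rng = random.Random(42)
n = 100_000
freq0 = sum(one_verification_step(p_target, p_draft, rng) == 0 for _ in range(n)) / n
print(freq0)  # should be close to 0.3
```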

Algorithm 1 in the paper summarizes the complete procedure and is a good reference for implementing Speculative Decoding accurately.
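
Since the algorithm listing is not reproduced in this post, here is a minimal Python sketch of one speculative decoding step along the lines described above; it is not the paper's code. Assumptions: `draft_model` and `target_model` are hypothetical callables that map a token list to a probability list, `make_toy_model` builds made-up stand-ins, gamma is the number of drafted tokens (t in this post), and the target model is called once per prefix in a loop, whereas a real implementation would score all gamma + 1 prefixes in one parallel forward pass.

```python
import random

VOCAB = 5  # hypothetical toy vocabulary size


def make_toy_model(favored_weight):
    """Hypothetical stand-in for a language model: it prefers the token
    (last + 1) % VOCAB, with favored_weight controlling how confidently."""
    def model(tokens):
        nxt = (tokens[-1] + 1) % VOCAB if tokens else 0
        weights = [favored_weight if i == nxt else 1.0 for i in range(VOCAB)]
        total = sum(weights)
        return [w / total for w in weights]
    return model


def sample(probs, rng):
    # Draw one token id; weights need not be normalized.
    return rng.choices(range(len(probs)), weights=probs)[0]


def speculative_step(prefix, draft_model, target_model, gamma, rng):
    """One round of speculative decoding; returns 1 to gamma + 1 new tokens."""
    # 1. The cheap draft model proposes gamma tokens autoregressively.
    ctx, drafts, q = list(prefix), [], []
    for _ in range(gamma):
        q_i = draft_model(ctx)
        x = sample(q_i, rng)
        drafts.append(x)
        q.append(q_i)
        ctx.append(x)
    # 2. The target model scores all gamma + 1 prefixes (one parallel pass in
    #    a real implementation; a loop here for clarity).
    p = [target_model(list(prefix) + drafts[:i]) for i in range(gamma + 1)]
    # 3. Verify drafts left to right and stop at the first rejection.
    out = []
    for i, x in enumerate(drafts):
        if rng.random() <= p[i][x] / q[i][x]:
            out.append(x)  # accepted with probability min(1, p/q)
            continue
        # Rejected: resample from norm(max(0, p - q)) and end this round.
        residual = [max(0.0, pt - qt) for pt, qt in zip(p[i], q[i])]
        out.append(sample(residual, rng))
        return out
    # 4. All drafts accepted: p[gamma] already covers the next position,
    #    so one extra token comes for free.
    out.append(sample(p[gamma], rng))
    return out


target = make_toy_model(8.0)
draft = make_toy_model(4.0)  # similar preferences, but less confident
rng = random.Random(0)
tokens = [0]
while len(tokens) < 20:
    tokens += speculative_step(tokens, draft, target, gamma=4, rng=rng)
print(tokens)
```

If the draft model is identical to the target model, every ratio is exactly 1, so each step returns gamma + 1 tokens: the best case from the discussion of worst and best cases.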


Conclusion

The rest of the paper presents experimental results and explains how the authors estimate the actual wall-time improvement. It is worth reading in detail, but I won't expand on it here due to space constraints.

In summary, this paper offers an inspiring approach to accelerating inference in large language models, and Speculative Decoding is genuinely practical and valuable. Numerous subsequent works continue to build on this method.

I plan to write a separate article that tests Speculative Decoding and measures its inference speedup in practice, and I hope to explore more innovative inference acceleration techniques along the way. I believe that once we break through the time constraints of computation, AI will truly integrate into society and benefit the public.

