
[Paper Reading] Hydra: Sequentially-Dependent Draft Heads for Medusa Decoding

Last Updated on 2025-07-01 by Clay

Currently, LLM inference time is dominated by the need to generate tokens one at a time. Each decoding step is limited by GPU memory bandwidth rather than by compute: for every single token decoded, the model’s entire set of weights must be read from memory, even though the actual floating-point computation per token is small. As a result, the GPU’s computational capacity is largely underutilized.
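A back-of-the-envelope calculation makes this concrete. The sketch below assumes a 7B-parameter model stored in fp16 (2 bytes per parameter) and an assumed memory bandwidth of 2 TB/s, a round number roughly in the range of a data-center GPU. It ignores the KV cache and activations, so the latency it reports is only a lower bound.

```python
def min_decode_latency_s(num_params: float, bytes_per_param: float, bandwidth_bytes_per_s: float) -> float:
    """Lower bound on per-token latency when every weight is read once per decoding step."""
    return num_params * bytes_per_param / bandwidth_bytes_per_s


def arithmetic_intensity(bytes_per_param: float) -> float:
    """Approximate FLOPs per byte at batch size 1: ~2 FLOPs (multiply + add) per weight read."""
    return 2.0 / bytes_per_param


# Assumed numbers: 7B parameters, fp16 weights, 2 TB/s memory bandwidth.
latency = min_decode_latency_s(7e9, 2, 2e12)
print(f"{latency * 1e3:.1f} ms/token, at most {1 / latency:.0f} tokens/s")  # 7.0 ms/token, at most 143 tokens/s
print(arithmetic_intensity(2))  # 1.0
```

Roughly one FLOP per byte is far below what the GPU’s arithmetic units can sustain, which is why getting several tokens out of each pass over the weights pays off.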

This is where speculative decoding comes in. A smaller, faster, but less capable draft model first generates several candidate tokens ahead of time. For example, given the input “Hi”, the draft model might generate “, what’s your name?”. The target model (the one we aim to accelerate) then verifies the entire sequence “Hi, what’s your name?” in a single forward pass, computing its own prediction at every position. We keep the drafted tokens up to the first one the target model disagrees with and discard the rest, so one forward pass of the target model can yield several tokens, achieving a speedup without sacrificing output quality.
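The accept/truncate step can be sketched for the greedy case. This is a minimal illustration, not any particular library’s implementation: `target_preds[i]` is assumed to hold the target model’s argmax prediction given the prefix plus the first i drafted tokens (all obtained from the single verification pass), so it has one more entry than the draft. With sampling instead of greedy decoding, speculative decoding replaces exact matching with a rejection-sampling rule.

```python
def greedy_verify(draft_tokens: list[int], target_preds: list[int]) -> list[int]:
    """Return the tokens committed after one verification pass (greedy decoding).

    draft_tokens: k tokens proposed by the draft model.
    target_preds: k + 1 argmax predictions from the target model, where
        target_preds[i] is its prediction given the prefix plus draft_tokens[:i].
    """
    assert len(target_preds) == len(draft_tokens) + 1
    accepted = []
    for draft, target in zip(draft_tokens, target_preds):
        if draft != target:
            break  # first disagreement: discard this and all later draft tokens
        accepted.append(draft)
    # The target model's own prediction at the first unverified position comes for free.
    accepted.append(target_preds[len(accepted)])
    return accepted


print(greedy_verify([5, 7, 9], [5, 7, 2, 4]))  # [5, 7, 2]
print(greedy_verify([1, 2], [1, 2, 3]))        # [1, 2, 3]
```

Even in the worst case (the very first draft token is wrong), a verification pass still commits one token, so it never produces fewer tokens per target-model pass than plain autoregressive decoding.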

The choice of the draft model is critical: it must be not only fast but also accurate, so as to yield a high acceptance rate during verification.
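To see why the acceptance rate matters so much, a common simplification assumes each drafted token is accepted independently with probability alpha. With gamma drafted tokens, one verification pass then yields on average sum_{k=0}^{gamma} alpha^k = (1 − alpha^(gamma+1)) / (1 − alpha) tokens: the accepted drafts plus the one token the target model always contributes. The sketch below only evaluates that formula; real acceptance events are not independent across positions, so treat it as a rough model.

```python
def expected_tokens_per_pass(alpha: float, gamma: int) -> float:
    """Expected tokens per target-model pass, assuming an i.i.d. acceptance rate alpha.

    Equals sum_{k=0}^{gamma} alpha**k: accepted drafts plus the target's own token.
    """
    if alpha >= 1.0:
        return float(gamma + 1)  # every draft accepted
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


print(expected_tokens_per_pass(0.5, gamma=3))  # 1.875
print(expected_tokens_per_pass(0.8, gamma=3))  # roughly 2.95
```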

A classic example is Medusa ([2401.10774] Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads), which replaces the separate draft model with a set of lightweight decoding heads attached to the target model’s last hidden state, alongside its original LM head. Each head predicts the token at a different future position.
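For comparison, here is a minimal PyTorch sketch of Medusa-style heads. As I understand the Medusa paper, each head is a single residual block with a SiLU activation followed by a projection to the vocabulary; the class name and the toy sizes are my own. The key point is that every head receives the same hidden state h and nothing else.

```python
import torch
import torch.nn as nn


class MedusaStyleHead(nn.Module):
    """Hypothetical Medusa-style head: logits = W2(SiLU(W1 h) + h)."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        self.w2 = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Residual block, then projection to vocabulary logits.
        return self.w2(self.act(self.w1(h)) + h)


torch.manual_seed(0)
heads = nn.ModuleList([MedusaStyleHead(hidden_size=16, vocab_size=100) for _ in range(3)])
h = torch.randn(16)  # stands in for the target model's last hidden state
with torch.no_grad():
    # All heads read the same h independently; head k guesses the token k + 1 steps
    # beyond the one predicted by the original LM head.
    drafts = [int(head(h).argmax()) for head in heads]
print(drafts)  # three token ids (random weights, so the values are arbitrary)
```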


Hydra Heads

The Hydra Heads proposed in this paper are similar to Medusa Heads in that each head predicts the token at a different future position. Unlike Medusa, however, each Hydra Head is also conditioned on the tokens already drafted by the preceding heads, which improves the acceptance rate.

Think of it like this: a Medusa head makes its prediction without knowing which token the previous head actually chose, relying only on the base model’s somewhat ambiguous hidden state, whereas Hydra explicitly feeds each decoding head the actual tokens chosen by the preceding heads.

P_{draft}(\hat x_{t+i}|x_{\le t},\hat x_{t+1},\ ...,\hat x_{t+i-1})=f_{Hydra,i}(h_{t-1},E_{x_{t}},E_{\hat x_{t+1}},\ ...,E_{\hat x_{t+i-1}})

The original paper writes x_t where I write E_{x_t}; I’ve adjusted the notation slightly so that the embedding E_{x_t} is distinguished from the token x_t itself.
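To make the sequential dependence concrete, here is a toy PyTorch sketch of drafting with Hydra-style heads for a single position. It is a simplified illustration under my own assumptions (one Linear-SiLU-Linear MLP per head, greedy argmax, random weights, a hypothetical `ToyHydraHeads` class), not the architecture from the repo. The point is only that head i’s input grows with the tokens chosen by the heads before it, mirroring f_{Hydra,i} in the formula above.

```python
import torch
import torch.nn as nn


class ToyHydraHeads(nn.Module):
    """Hypothetical, simplified Hydra heads (not the repo's HydraMLP)."""

    def __init__(self, hidden_size: int, vocab_size: int, num_heads: int):
        super().__init__()
        # Stands in for the base model's embedding layer (embed_tokens).
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # Head i consumes h plus (i + 1) token embeddings: x_t and i drafted tokens.
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear((i + 2) * hidden_size, hidden_size),
                nn.SiLU(),
                nn.Linear(hidden_size, vocab_size),
            )
            for i in range(num_heads)
        ])

    @torch.no_grad()
    def draft(self, h: torch.Tensor, last_token: int) -> list[int]:
        """Greedily draft one token per head.

        h: base hidden state h_{t-1}, shape (hidden_size,)
        last_token: id of x_t, the token the base model just produced
        """
        inputs = [h, self.embed(torch.tensor(last_token))]
        drafted = []
        for head in self.heads:
            logits = head(torch.cat(inputs, dim=-1))
            token = int(logits.argmax())
            drafted.append(token)
            # Sequential dependence: every later head also sees this token's embedding.
            inputs.append(self.embed(torch.tensor(token)))
        return drafted


torch.manual_seed(0)
hydra = ToyHydraHeads(hidden_size=16, vocab_size=100, num_heads=4)
drafted = hydra.draft(torch.randn(16), last_token=42)
print(drafted)  # four token ids (random weights, so the values are arbitrary)
```

A Medusa head would receive only h, so changing what head 0 picks could never change what head 1 predicts; here it can.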

Looking at the source code, a standard Hydra Head is built as an MLP (HydraMLP):

self.hydra_head = HydraMLP(
    hydra_num_layers=self.hydra_num_layers,
    hydra_num_heads=self.hydra,
    grounded_heads=self.grounded_heads,
    input_embed_fn=self.base_model.model.embed_tokens,
    base_config=self.config,
    lm_head_init_weight=base_model.lm_head.weight.data
)
self.hydra_lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 


In addition to the number of layers and heads, the constructor takes an input_embed_fn, which here is the base model’s token embedding layer (embed_tokens); the heads use it to embed the tokens they are conditioned on.

Inside the forward() function of HydraMLP, we can clearly see how the Hydra Heads operate:

def forward(self, base_hidden_states, input_ids=None, noise=None):
    """
    Forward pass of the Hydra heads.

    Args:
        base_hidden_states (torch.Tensor): Final hidden states of the base model.
        input_ids (torch.Tensor, optional): Token ids; required when grounded_heads is True.
        noise (torch.Tensor, optional): Noise added to the input embeddings.
    """

    hydra_hidden_states = []
    if self.grounded_heads:
        assert input_ids is not None, "Input ids must be provided for grounded heads"
        # Embed tokens with the base model's embedding layer, without tracking gradients
        with torch.inference_mode():
            input_embeds = self.input_embed_fn(input_ids)
        if noise is not None:
            input_embeds = input_embeds + noise
        hydra_inputs = [base_hidden_states]
        for i in range(self.hydra_num_heads):
            # Move input embeddings back one spot for each hydra head idx
            hydra_inputs.append(torch.roll(input_embeds, shifts=-(i+1), dims=1))

        for i in range(self.hydra_num_heads):
            # Head i sees the base hidden state plus the first i + 1 shifted embeddings
            head_input = torch.cat(hydra_inputs[:i + 2], dim=-1)
            hydra_hidden_states.append(self.hydra_mlp[i](head_input))
    else:
        # Ungrounded (Medusa-style) heads: every head sees only the base hidden state
        for i in range(self.hydra_num_heads):
            hydra_hidden_states.append(self.hydra_mlp[i](base_hidden_states))
    # ... (rest of forward() omitted in this excerpt)
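The torch.roll shifting is the subtle part. A tiny example on token ids (rolling ids along the sequence axis moves them exactly as it moves their embeddings) shows that after rolling by -(i+1), position t holds token t+i+1. As I read the code, during training the head at position t is therefore fed the real tokens t+1, ..., t+i+1. Note that torch.roll wraps around: the last i+1 positions receive tokens from the start of the sequence, and the excerpt above does not show how those positions are handled.

```python
import torch

# Toy sequence of token ids at positions 0..3 (batch size 1).
input_ids = torch.tensor([[10, 11, 12, 13]])

# Same roll as in HydraMLP.forward, applied to ids instead of embeddings for readability.
shifted = [torch.roll(input_ids, shifts=-(i + 1), dims=1) for i in range(2)]
print(shifted[0])  # tensor([[11, 12, 13, 10]])  position t now holds token t+1
print(shifted[1])  # tensor([[12, 13, 10, 11]])  position t now holds token t+2
```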


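Note how the token ids are turned into embeddings with the base model’s own embedding layer, inside torch.inference_mode() so that no gradients are tracked through it: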

with torch.inference_mode():
    input_embeds = self.input_embed_fn(input_ids)


The research team conducted experiments showing that Hydra achieves a 1.1× speedup over Medusa.


Hydra++

Hydra++ is the researchers’ improved recipe for Hydra, combining changes to the training objective with refinements to the head architecture. It raises throughput to 1.31× that of Medusa and 2.70× that of standard autoregressive decoding.

Three improvements proved effective (see Appendix A of the paper):

  1. Scaling: Each MLP within the heads was scaled to 4 layers — experiments found that going beyond 5 layers offered no further benefit.
  2. Distillation: Applied self-distillation, training Hydra Heads to match the target model’s predicted distribution at each position rather than the actual next token in the training data.
  3. Prefix Attention: To help the draft model make better use of the context, an extra self-attention decoder layer, separate from the frozen target model, was added to the draft model. It is run only once per decoding step and supplies the Hydra Heads with more informative hidden-state inputs.
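The distillation objective in item 2 can be written as a cross-entropy against soft labels. This is a minimal pure-Python sketch for a single position: `teacher_probs` is the target model’s predicted distribution and `student_logits` are a Hydra head’s logits. The paper’s exact loss may differ in details, so treat it as an illustration. With a one-hot `teacher_probs` it reduces to the ordinary next-token loss, which is what training on the dataset’s actual tokens amounts to.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def distill_loss(teacher_probs: list[float], student_logits: list[float]) -> float:
    """Soft-label cross-entropy: -sum_v p_target(v) * log q_head(v)."""
    log_q = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(teacher_probs, log_q) if p > 0)


print(distill_loss([0.7, 0.2, 0.1], [2.0, 0.5, -1.0]))
print(distill_loss([1.0, 0.0], [0.0, 0.0]))  # one-hot target: plain cross-entropy, log(2)
```

Minimizing this over the head’s parameters is equivalent to minimizing the KL divergence from the target model’s distribution, because the target’s entropy term does not depend on the head.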

Tree Decoding

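Like Medusa, Hydra decoding verifies candidates arranged in a predefined, static tree: each root-to-node path is one candidate continuation, and the target model checks all of them in a single forward pass. The challenge then becomes: how do we find the best tree structure?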

The team used an iterative greedy strategy: starting from a tree with just one node, at each step they identified the existing node whose new child would increase the expected accepted length the most, added that child, and repeated this process until the tree reached the desired size.
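The greedy procedure can be sketched under a simplifying assumption of my own: the probability that a node is accepted, given that its parent was, depends only on its depth and on its rank among its siblings (rank 0 = that head’s top-1 candidate), given by a hypothetical table `acc` that one would estimate on held-out data. A node’s acceptance probability is then the product along its path, and the expected accepted length is the sum of these probabilities over all non-root nodes. The numbers in the example are made up.

```python
def build_tree_greedy(acc: list[list[float]], num_nodes: int):
    """Greedily grow a static draft tree.

    acc[d][r]: assumed probability that the rank-r candidate at depth d + 1 is
        correct, given that its parent was accepted.
    Each node is identified by its path of ranks from the root, e.g. (0, 1).
    Returns (sorted list of node paths, expected accepted length).
    """
    prob = {(): 1.0}        # node path -> probability the whole path is accepted
    num_children = {(): 0}  # node path -> number of children added so far
    while len(prob) - 1 < num_nodes:
        best_gain, best_child = 0.0, None
        for path, p in prob.items():
            depth, rank = len(path), num_children[path]
            if depth >= len(acc) or rank >= len(acc[depth]):
                continue  # no more heads or candidates available under this node
            gain = p * acc[depth][rank]  # increase in expected accepted length
            if gain > best_gain:
                best_gain, best_child = gain, path + (rank,)
        if best_child is None:
            break
        prob[best_child] = best_gain
        num_children[best_child] = 0
        num_children[best_child[:-1]] += 1
    expected_len = sum(p for path, p in prob.items() if path)
    return sorted(path for path in prob if path), expected_len


# Made-up acceptance table for 2 heads with 2 candidates each.
acc = [[0.6, 0.2], [0.5, 0.1]]
nodes, expected = build_tree_greedy(acc, num_nodes=3)
print(nodes)     # [(0,), (0, 0), (1,)]
print(expected)  # about 1.1
```

Because Hydra’s heads condition on earlier draft tokens, a table indexed only by depth and rank is an approximation; it is used here just to show the greedy growth step.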


Training Details

Model: The base models used were from the Vicuna series (7B, 13B, 33B parameters), which are dialogue-tuned variants of LLaMA.

Training:

  • Only the Draft Heads were trained; base model weights were kept frozen.
  • Training Dataset: ShareGPT (multi-turn dialogue data)
  • Hardware: 8 NVIDIA A100-80GB GPUs
  • Framework: Hugging Face Trainer
  • Optimizer: AdamW (\beta_{1}=0.9,\ \beta_{2}=0.999)
  • Learning Rate: Cosine learning-rate schedule with warmup; peak learning rate 1e-3.
  • Epochs: 1 epoch for Hydra and Medusa Heads (performance had already saturated); 10 epochs for Hydra++ Heads.
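The learning-rate schedule listed above (warmup to a peak of 1e-3, then cosine decay) can be sketched as a plain function of the training step. The post does not give the warmup length or the total number of steps, so the values in the example are placeholders. The shape used here (linear warmup, then half a cosine down to zero) is the common form and, as far as I know, also what Hugging Face’s cosine-with-warmup scheduler implements.

```python
import math


def cosine_lr_with_warmup(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay from peak_lr to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at 0 after the final step
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Placeholder schedule: 10 warmup steps out of 110 total, peak 1e-3 as in the post.
for step in (0, 5, 10, 60, 110):
    print(step, cosine_lr_with_warmup(step, 1e-3, warmup_steps=10, total_steps=110))
```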
