
Supporting Hydra Speculative Decoding on TensorRT-LLM Python Session

Last Updated on 2025-07-01 by Clay

Introduction

I’ve previously studied many speculative decoding techniques and implemented several of them in PyTorch, including the model architectures, training scripts, and inference scripts (fast-llm-inference). This time, I have a new goal.

A month or two ago I read the Hydra paper; Hydra can be viewed as a variant of the Medusa architecture. As a small personal side project, I wanted to support the official Hydra weights in the TensorRT-LLM acceleration framework.

Goal: on TensorRT-LLM, using the Python Session, have the Hydra Heads generate reasonable draft tokens for the main model to verify. (If you’re interested in the branch I implemented: it has not been discussed with the TensorRT-LLM team yet and is simply hosted on my GitHub: support-spec-decode-hydra.)


Hydra Overview

For more details, you can refer to the notes I previously wrote: [Paper Reading] Hydra: Sequentially-Dependent Draft Heads for Medusa Decoding.

In short, Hydra and Medusa are very similar in architecture: each Head predicts the token at a different future time step. However, Hydra makes each Head depend on the tokens generated by the previous Heads, which significantly improves the acceptance rate of the Hydra Heads.

Intuitively, a Medusa Head predicts without knowing exactly which tokens the previous Heads produced; it only sees the base model’s hidden states, which carry that information only vaguely. Hydra, by contrast, explicitly tells each Head which tokens were generated before it, by concatenating their token embeddings to the Head’s input.
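The difference can be written as two conditional distributions. This is my own shorthand rather than the paper’s notation: h_t is the base model’s last hidden state at step t, x_{t+1} is the token the base model itself predicts, and draft head i (i = 1, 2, ...) proposes x_{t+i+1}.

```latex
\begin{aligned}
\text{Medusa head } i &: \; p_i\!\left(x_{t+i+1} \mid h_t\right) \\
\text{Hydra head } i  &: \; p_i\!\left(x_{t+i+1} \mid h_t,\; x_{t+1}, \dots, x_{t+i}\right)
\end{aligned}
```

Hydra head i therefore consumes i token embeddings on top of h_t, which is why its input grows wider the later the head sits in the chain.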

By reading the official implementation source code (prefix_mlp_head.py), we can see the following:

self.hydra_head = HydraMLP(
    hydra_num_layers=self.hydra_num_layers,
    hydra_num_heads=self.hydra,
    grounded_heads=self.grounded_heads,
    input_embed_fn=self.base_model.model.embed_tokens,
    base_config=self.config,
    lm_head_init_weight=base_model.lm_head.weight.data
)
self.hydra_lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)


In addition to the number of layers and heads, we also need to provide input_embed_fn, which here is the base model’s token embedding layer (base_model.model.embed_tokens).

In the forward() method of HydraMLP, we can clearly see how the Hydra Heads operate:

def forward(self, base_hidden_states, input_ids=None, noise=None):
    """
    Forward pass of the MLP.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Output after the MLP.
    """

    hydra_hidden_states = []
    if self.grounded_heads:
        assert input_ids is not None, "Input ids must be provided for grounded heads"
        with torch.inference_mode():
            input_embeds = self.input_embed_fn(input_ids)
        if noise is not None:
            input_embeds = input_embeds + noise
        hydra_inputs = [base_hidden_states]
        for i in range(self.hydra_num_heads):
            # Move input embeddings back one spot for each hydra head idx
            hydra_inputs.append(torch.roll(input_embeds, shifts=-(i+1), dims=1))
        
        for i in range(self.hydra_num_heads):
            head_input = torch.cat(hydra_inputs[:i + 2], dim=-1) 
            hydra_hidden_states.append(self.hydra_mlp[i](head_input))
    else:
        for i in range(self.hydra_num_heads):
            hydra_hidden_states.append(self.hydra_mlp[i](base_hidden_states))


We can also clearly see how tokens are transformed into embeddings:

with torch.inference_mode():
    input_embeds = self.input_embed_fn(input_ids)
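To make the torch.roll / torch.cat bookkeeping in the grounded-heads branch concrete, here is a small NumPy stand-in with toy sizes. np.roll and np.concatenate behave like torch.roll and torch.cat for this purpose. The sizes, the random embedding table, and the variable names are assumptions for illustration; this is a sketch of the idea, not the official code.

```python
import numpy as np

# Toy sizes (assumed for illustration); Vicuna-7B would use hidden=4096, vocab=32000.
batch, seq_len, hidden, vocab, num_heads = 2, 6, 8, 50, 4
rng = np.random.default_rng(0)

embed_table = rng.standard_normal((vocab, hidden))         # stands in for embed_tokens
input_ids = rng.integers(0, vocab, size=(batch, seq_len))  # [B, S]
base_hidden_states = rng.standard_normal((batch, seq_len, hidden))

input_embeds = embed_table[input_ids]                      # embedding lookup: [B, S, H]

hydra_inputs = [base_hidden_states]
for i in range(num_heads):
    # Shift left by i+1 along the sequence axis, so position t holds the
    # embedding of token t+i+1 (the last i+1 positions wrap around).
    hydra_inputs.append(np.roll(input_embeds, shift=-(i + 1), axis=1))

# Head i sees the base hidden state plus i+1 token embeddings.
head_inputs = [np.concatenate(hydra_inputs[:i + 2], axis=-1) for i in range(num_heads)]
for i, x in enumerate(head_inputs):
    print(f"head {i}: input shape {x.shape}")
```

The printed shapes are (2, 6, 16), (2, 6, 24), (2, 6, 32), and (2, 6, 40): every head adds one hidden size to the input. Note that the last i+1 positions of each shifted copy have no real future token; they wrap around to the start of the sequence.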


According to the research team’s experiments, Hydra is about 1.1x faster than Medusa.

Now, I aim to support it on TensorRT-LLM and run it successfully.


Implementation Concepts & Process

So, to actually make it work on TensorRT-LLM, what needs to be done?

First, I set a concrete target: since I wanted to support the official Hydra model, I downloaded the base Vicuna-7B model (lmsys--vicuna-7b-v1.3) and the official Hydra weights from the Hugging Face repository ankner--hydra-vicuna-7b-v1.3.

After downloading, I read its config.json, which clearly indicated that this is the “prefix-mlp” variant. This design caused me a lot of trouble later on XD

The original Medusa Heads work in parallel: the model’s hidden states can be passed directly to every Head at once. Hydra’s prefix-mlp variant is different. The base model’s hidden states first pass through a single-layer Llama model (the prefix layer), and its output is then concatenated with the embedding of the base model’s most probable next token before being passed to the Hydra Heads.

Additionally, Hydra Heads are not parallel; they run sequentially. Token 1 generated by Head 1 is converted into Embedding 1 and concatenated with the earlier features to form the input of Head 2, and so on.

As a result, the input dimension of each successive Hydra Head grows by one hidden size (4096 for Vicuna-7B): 8192, 12288, 16384, 20480, and so on. I’ll explain this challenge in more detail later.
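The inference-time flow described above, where the base model’s top-1 token feeds Head 0 and each head’s top-1 token feeds the next, can be sketched as a plain for-loop. This is a NumPy toy under stated assumptions: each head is replaced by a single random linear map (hypothetical; the real heads are residual MLP stacks followed by an LM head), prefix_hidden stands in for the output of the single-layer Llama prefix, and only the top-1 path is drafted instead of a full token tree. draft_greedy and head_weights are illustrative names, not part of TensorRT-LLM or the Hydra repository.

```python
import numpy as np

# Toy sizes (assumed); Vicuna-7B has hidden=4096, vocab=32000.
batch, hidden, vocab, num_heads = 2, 8, 50, 4
rng = np.random.default_rng(0)

embed_table = rng.standard_normal((vocab, hidden))  # base model's token embeddings
# Hypothetical stand-ins: one random linear map per head, (i + 2) * hidden -> vocab.
head_weights = [rng.standard_normal(((i + 2) * hidden, vocab)) for i in range(num_heads)]


def draft_greedy(prefix_hidden, base_logits):
    """Sequentially draft one token per head, keeping only the top-1 path."""
    next_ids = base_logits.argmax(-1)                      # base model's top-1: [B]
    head_input = np.concatenate([prefix_hidden, embed_table[next_ids]], axis=-1)
    drafts, widths = [], []
    for w in head_weights:
        widths.append(head_input.shape[-1])
        logits = head_input @ w                            # [B, vocab]
        next_ids = logits.argmax(-1)
        drafts.append(next_ids)
        # Condition the next head on the token this head just picked.
        head_input = np.concatenate([head_input, embed_table[next_ids]], axis=-1)
    return np.stack(drafts, axis=1), widths


prefix_hidden = rng.standard_normal((batch, hidden))  # stands in for the prefix layer output
base_logits = rng.standard_normal((batch, vocab))
drafts, widths = draft_greedy(prefix_hidden, base_logits)
print(drafts.shape, widths)                          # (2, 4) [16, 24, 32, 40]
print([(i + 2) * 4096 for i in range(num_heads)])   # [8192, 12288, 16384, 20480]
```

With the toy hidden size of 8 the head input widths are 16, 24, 32, 40; plugging in 4096 reproduces the 8192, 12288, 16384, 20480 sequence above. Because each head’s input depends on the previous head’s argmax, the heads cannot be evaluated in one parallel batch the way Medusa Heads can.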

Next, note that TensorRT-LLM has both Python and C++ backends; my goal this time is the Python backend (the Python session, selected with --use_py_session in run.sh). TensorRT-LLM also already supports Medusa, and since Hydra is very similar, I mainly mimicked the Medusa implementation to define the Hydra architecture.

So here’s what I did:

  • Created a new hydra directory under tensorrt_llm/models/ (copied from medusa/)
  • Redefined model.py, weight.py, and config.py (GitHub link)
  • To verify the implementation, I created examples/hydra/ and wrote convert.sh, build.sh, and run.sh scripts to confirm the model can run
  • Modified examples/run.py and tensorrt_llm/runtime/generation.py to add Hydra support, mostly copying the Medusa flow but updating paths and flags so that --speculative_decoding_mode hydra can be used

examples/hydra/convert.sh, build.sh, run.sh

First, convert.sh, the first script to execute:

#!/bin/bash


python convert_checkpoint.py \
    --model_dir ./lmsys--vicuna-7b-v1.3/ \
    --hydra_model_dir ./ankner--hydra-vicuna-7b-v1.3 \
    --output_dir ./tllm_checkpoint_1gpu_hydra \
    --dtype float16 \
    --num_hydra_heads 4 \
    --num_hydra_layers 4


This script converts a pretrained large language model (e.g., LLaMA or Qwen2), together with the Hydra head weights, into the TensorRT-LLM checkpoint format, which can then be built into a fast inference engine.

#!/bin/bash


trtllm-build \
    --checkpoint_dir tllm_checkpoint_1gpu_hydra \
    --output_dir ./tmp/hydra/7B/trt_engine/fp16/1-gpu/ \
    --gemm_plugin float16 \
    --speculative_decoding_mode hydra \
    --max_batch_size 4

build.sh is the key script for building a Hydra-based LLM with TensorRT-LLM: trtllm-build reads the checkpoint generated by convert.sh and compiles it into an optimized engine.

If the model definition is broken, or if hydra/weight.py has trouble reading the weights, the build will fail at this step. Incorrect graph shapes also trigger build errors.
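Because checkpoint and graph mismatches only surface at build time, a quick pre-build check can save a build cycle. Below is a hypothetical helper (not part of TensorRT-LLM or my branch) that lists the weight shapes the Hydra heads expect, derived from the HydraResBlock and HydraPrefixMLP classes in model.py (listed later in this post). It assumes tp_size = 1, that ColumnLinear stores its weight as (out_features, in_features), and that checkpoint keys mirror model.py’s attribute names such as hydra_heads.0.hydra_mlp.linear.weight. Biases are left out for brevity.

```python
def expected_hydra_head_shapes(hidden, vocab, num_heads, num_layers):
    """Hypothetical helper: expected weight shapes of the Hydra heads in model.py.

    Assumes tp_size = 1, ColumnLinear weights stored as (out_features, in_features),
    and checkpoint keys that mirror model.py's attribute names. Biases are omitted.
    """
    shapes = {}
    for i in range(num_heads):
        prefix = f"hydra_heads.{i}"
        # HydraResBlock: input_size = hidden * (num_condition + 1), num_condition = i + 1
        in_size = hidden * (i + 2)
        shapes[f"{prefix}.hydra_mlp.linear.weight"] = (hidden, in_size)
        shapes[f"{prefix}.hydra_mlp.res_connection.weight"] = (hidden, in_size)
        # GenericHydraForCausalLM passes num_layers = num_hydra_layers - 1; these blocks
        # have num_condition = 0, so their res_connection is an Identity (no weight).
        for j in range(num_layers - 1):
            shapes[f"{prefix}.hydra_mlps.{j}.linear.weight"] = (hidden, hidden)
        shapes[f"{prefix}.hydra_lm_head.weight"] = (vocab, hidden)
    return shapes


def check_shapes(weights, expected):
    """Compare a name -> array mapping against expected shapes; return problems."""
    problems = []
    for name, shape in expected.items():
        if name not in weights:
            problems.append(f"missing: {name}")
        elif tuple(weights[name].shape) != shape:
            problems.append(f"{name}: got {tuple(weights[name].shape)}, expected {shape}")
    return problems


# Vicuna-7B: hidden size 4096, vocab 32000; 4 heads and 4 layers as in convert.sh.
expected = expected_hydra_head_shapes(hidden=4096, vocab=32000, num_heads=4, num_layers=4)
print(expected["hydra_heads.3.hydra_mlp.linear.weight"])  # (4096, 20480)
```

If the converted weights are available as a name-to-array dict (an assumption about how you collect them from weight.py), check_shapes lists every missing or mis-shaped tensor before trtllm-build is invoked.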

#!/bin/bash



python ../run.py \
    --engine_dir ./tmp/hydra/7B/trt_engine/fp16/1-gpu/ \
    --use_py_session \
    --tokenizer_dir ./lmsys--vicuna-7b-v1.3/ \
    --max_output_len=100 \
    --hydra_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
    --temperature 1.0 \
    --input_text "Once upon" \
    --debug_mode


With the engine built, run.sh makes the rest simple: provide your --input_text and check Hydra’s generation results.
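The --hydra_choices string describes a token tree. Assuming it follows the same convention as TensorRT-LLM’s Medusa choices (which my Hydra flow copies), each inner list is a path from the root, and the value at depth d is the top-k rank taken from head d: [0, 1] means “top-1 candidate of the first head, then top-2 candidate of the second head”. The helper below is hypothetical (not part of run.py); it parses the string with ast.literal_eval and checks two properties such a tree needs: no path deeper than the number of heads, and every path’s parent also present.

```python
import ast
from collections import Counter


def check_hydra_choices(choices_str, num_heads):
    """Parse a --hydra_choices string and sanity-check it as a token tree."""
    choices = [tuple(c) for c in ast.literal_eval(choices_str)]
    paths = set(choices)
    for c in choices:
        # Each path holds one top-k rank per head, so depth must not exceed num_heads.
        if not 1 <= len(c) <= num_heads:
            raise ValueError(f"{list(c)} is deeper than the {num_heads} heads")
        # The tree must be prefix-closed: every parent path must also be listed.
        if len(c) > 1 and c[:-1] not in paths:
            raise ValueError(f"parent of {list(c)} is missing")
    return Counter(len(c) for c in choices)  # number of tree nodes per depth


print(check_hydra_choices("[[0], [1], [0, 0], [0, 1], [0, 0, 0]]", num_heads=4))
```

Run on the list from run.sh, this reports 10, 28, 23, and 2 nodes at depths 1 to 4 (63 paths in total), so the tree fits the 4 Hydra heads set by --num_hydra_heads 4.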


Problems Encountered

There were many pitfalls during implementation. Below are a few particularly painful ones that initially seemed unfixable.

In retrospect, the main issues stemmed from:

  1. Couldn’t retrieve Hydra logits: the acceptance rate was 0 at first. It turned out I had to explicitly write medusa_logits.mark_output('medusa_logits', self.config.logits_dtype) in model.py. Renaming the output also caused failures, probably because the name medusa_logits is hardcoded somewhere.
  2. CUDA illegal memory access: I debugged this for a long time before finding that the prefix_embedding_layer was the root cause. Because this layer bypasses the embedding layer (it takes hidden states directly), calling TensorRT-LLM’s internal Llama model for it resulted in CUDA errors. In the end, I ported Hugging Face’s LlamaModel and rewrote all of its internal operations with TensorRT-LLM modules, so the layer still goes through TensorRT graph optimization.
  3. AssertionError: tensor /concat_L2636/CONCATENATION_1_output_0 has an invalid shape: TensorRT failed to infer dynamic shapes because of the ever-growing Hydra input dimensions. Eventually I hardcoded the loop by writing out each head explicitly, which is a bad solution; I plan to refactor it so the number of heads is no longer hardcoded.

So my final model.py looks like this:

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
from typing import Optional, Union

import numpy as np
import torch
from transformers import AutoModelForCausalLM

from tensorrt_llm._utils import numpy_to_torch
from tensorrt_llm.models.hydra.weight import load_hydra_hf
from tensorrt_llm.models.llama.model import LLaMAForCausalLM, RmsNorm
from tensorrt_llm.models.qwen.model import QWenForCausalLM

from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import (ACT2FN, add, cast, concat, constant, cos, div,
                           expand, matmul, mul, shape, sin, slice, softmax,
                           squeeze, stack, topk, transpose, unsqueeze, view)
from ...layers import ColumnLinear
from ...mapping import Mapping
from ...module import Module, ModuleList
from ..modeling_utils import PretrainedModel, QuantConfig
from .config import HydraConfig
from .weight import convert_hf_llama


# refer: https://github.com/zankner/Hydra/blob/main/hydra/model/hydra_heads/prefix_mlp_head.py#L44
class HydraResBlock(Module):

    def __init__(
            self,
            hidden_size,
            hidden_act="silu",
            num_condition=0,
            dtype=None,
            mapping=Mapping(),
    ):
        super().__init__()

        input_size = hidden_size * (num_condition + 1)
        self.linear = ColumnLinear(input_size,
                                   hidden_size,
                                   dtype=dtype,
                                   tp_group=mapping.tp_group,
                                   tp_size=mapping.tp_size,
                                   gather_output=True)
        self.res_connection = ColumnLinear(
            input_size,
            hidden_size,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True) if num_condition > 0 else torch.nn.Identity()

        self.hidden_act = hidden_act

    def forward(self, x):
        return self.res_connection(x) + ACT2FN[self.hidden_act](self.linear(x))


class HydraPrefixMLP(Module):

    def __init__(
            self,
            num_layers,
            hidden_size,
            vocab_size,
            hydra_head_idx,
            hidden_act="silu",
            dtype=None,
            mapping=Mapping(),
            lm_head_init_weight=None,
    ):
        super().__init__()
        self.hydra_mlp = HydraResBlock(hidden_size=hidden_size,
                                       num_condition=hydra_head_idx + 1,
                                       hidden_act=hidden_act,
                                       dtype=dtype,
                                       mapping=mapping)

        self.hydra_mlps = ModuleList([
            HydraResBlock(hidden_size=hidden_size,
                          hidden_act=hidden_act,
                          dtype=dtype,
                          mapping=mapping) for _ in range(num_layers)
        ])
        self.hydra_lm_head = ColumnLinear(hidden_size,
                                          vocab_size,
                                          bias=True,
                                          dtype=dtype,
                                          tp_group=mapping.tp_group,
                                          tp_size=mapping.tp_size,
                                          gather_output=True)

    def forward(self, x):
        hidden_states = self.hydra_mlp(x)

        for layer in self.hydra_mlps:
            hidden_states = layer(hidden_states)

        return self.hydra_lm_head(hidden_states)


def _compute_default_rope_parameters(
    config: Optional[HydraConfig] = None,
    **rope_kwargs,
):
    # if len(rope_kwargs) > 0:
    #     base = rope_kwargs["base"]
    #     dim = rope_kwargs["dim"]
    # elif config is not None:
    #     base = config.rope_theta
    #     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    #     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    #     dim = int(head_dim * partial_rotary_factor)

    base = getattr(config, "rope_theta", 10000.0)
    head_dim = config.hidden_size // config.num_attention_heads
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    dim = int(head_dim * partial_rotary_factor)

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies
    idx = np.arange(0, dim, 2, dtype=np.float32)
    inv_freq = 1.0 / (base**(idx / dim))

    return inv_freq, attention_factor


def _compute_llama3_parameters(
    config: HydraConfig,
    **rope_kwargs,
):
    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(
        config, **rope_kwargs)

    factor = 8.0
    low_freq_factor = 1.0
    high_freq_factor = 4.0
    old_context_len = 8192.0

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / inv_freq

    # Use numpy
    inv_freq_llama = np.where(wavelen > low_freq_wavelen, inv_freq / factor,
                              inv_freq)

    smooth_factor = (old_context_len / wavelen -
                     low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed_inv_freq = (
        1 - smooth_factor
    ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama

    is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen
                                                      >= high_freq_wavelen)

    inv_freq_llama = np.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    return inv_freq_llama, attention_factor


def rotate_half(x):
    """
    TensorRT-LLM functional version of rotate_half.
    Assumes x is a 4D tensor: [batch, num_heads, seq_len, head_dim]
    Splits the last dimension in half, rotates the halves, and concatenates them.
    """
    # Get dimensions as scalar tensors
    dim0 = squeeze(shape(x, 0), dim=0)
    dim1 = squeeze(shape(x, 1), dim=0)
    dim2 = squeeze(shape(x, 2), dim=0)
    last_dim = squeeze(shape(x, 3), dim=0)

    # Compute half of last_dim
    two = constant(np.array([2], dtype="int64"))
    half_dim = squeeze(div(last_dim, two), dim=0)

    # Create scalar zero for use in starts
    zero = constant(np.array(0, dtype="int64"))

    # Define starts and sizes for slicing
    starts1 = stack([zero, zero, zero, zero], dim=0)
    sizes1 = stack([dim0, dim1, dim2, half_dim], dim=0)
    starts2 = stack([zero, zero, zero, half_dim], dim=0)

    # Slice tensors into two halves along the last dimension
    x1 = slice(x, starts=starts1, sizes=sizes1)
    x2 = slice(x, starts=starts2, sizes=sizes1)

    # Negate the second half and concatenate
    neg_x2 = mul(x2, -1.0)
    return concat([neg_x2, x1], dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """
    TensorRT-LLM functional version of apply_rotary_pos_emb.
    """
    # PyTorch: cos.unsqueeze(unsqueeze_dim)
    cos_expanded = unsqueeze(cos, unsqueeze_dim)
    sin_expanded = unsqueeze(sin, unsqueeze_dim)

    # PyTorch: (q * cos) + (rotate_half(q) * sin)
    rotated_q = rotate_half(q)
    q_embed = add(mul(q, cos_expanded), mul(rotated_q, sin_expanded))

    # PyTorch: (k * cos) + (rotate_half(k) * sin)
    rotated_k = rotate_half(k)
    k_embed = add(mul(k, cos_expanded), mul(rotated_k, sin_expanded))

    return q_embed, k_embed


def repeat_kv(hidden_states, n_rep: int):
    """
    TensorRT-LLM functional version of repeat_kv
    """
    if n_rep == 1:
        return hidden_states

    batch, num_key_value_heads, slen, head_dim = (shape(hidden_states, 0),
                                                  shape(hidden_states, 1),
                                                  shape(hidden_states, 2),
                                                  shape(hidden_states, 3))

    hidden_states_unsqueezed = unsqueeze(hidden_states, 2)
    hidden_states_expanded = expand(
        hidden_states_unsqueezed,
        [batch, num_key_value_heads, n_rep, slen, head_dim])

    final_shape = [batch, num_key_value_heads * n_rep, slen, head_dim]
    return view(hidden_states_expanded, final_shape)


def eager_attention_forward(
    query,
    key,
    value,
    num_key_value_groups: int,
    scaling: float,
    dropout: float = 0.0,
    attention_mask=None,
    **kwargs,
):
    key_states = repeat_kv(key, num_key_value_groups)
    value_states = repeat_kv(value, num_key_value_groups)

    # Attention scores: (Q @ K.T) * scaling
    key_states_T = transpose(key_states, 2, 3)
    attn_scores = matmul(query, key_states_T)
    attn_scores_scaled = mul(attn_scores, scaling)

    if attention_mask is not None:
        key_len = shape(key_states, 2)
        mask_shape = shape(attention_mask)

        causal_mask = slice(
            attention_mask,
            starts=[0, 0, 0, 0],
            sizes=[mask_shape[0], mask_shape[1], mask_shape[2], key_len])
        attn_scores_masked = add(attn_scores_scaled, causal_mask)
    else:
        attn_scores_masked = attn_scores_scaled

    # Softmax
    query_dtype = query.dtype
    attn_weights_fp32 = softmax(cast(attn_scores_masked, "float32"), dim=-1)
    attn_weights = cast(attn_weights_fp32, query_dtype)

    # Ignore dropout
    # Compute Attention Output: attn_weights @ V
    attn_output = matmul(attn_weights, value_states)

    # Transpose
    attn_output = transpose(attn_output, 1, 2)

    return attn_output


class LlamaRotaryEmbedding(Module):

    def __init__(self, config: HydraConfig, mapping=Mapping()):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = _compute_llama3_parameters

        self.original_inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config)

    def forward(self, x, position_ids):
        x_dtype = x.dtype

        inv_freq = constant(self.original_inv_freq)

        inv_freq_unsqueezed = unsqueeze(unsqueeze(inv_freq, 0), 2)
        b = shape(position_ids, 0)
        inv_freq_dim0 = shape(inv_freq, 0)
        one = constant(np.array([1], dtype=np.int64))

        expand_shape = concat(
            [unsqueeze(b, 0), unsqueeze(inv_freq_dim0, 0), one], dim=0)

        inv_freq_expanded = expand(inv_freq_unsqueezed, expand_shape)
        position_ids_expanded = unsqueeze(position_ids, 1)

        inv_freq_float32 = cast(inv_freq_expanded, "float32")
        position_ids_float32 = cast(position_ids_expanded, "float32")

        freqs_t = matmul(inv_freq_float32, position_ids_float32)
        freqs = transpose(freqs_t, 1, 2)

        emb = concat([freqs, freqs], dim=-1)

        # 6. Apply cos, sin, and scaling
        # PyTorch: emb.cos() * self.attention_scaling
        cos_emb = cos(emb)
        cos_scaled = mul(cos_emb, self.attention_scaling)

        # PyTorch: emb.sin() * self.attention_scaling
        sin_emb = sin(emb)
        sin_scaled = mul(sin_emb, self.attention_scaling)

        # 7. Cast back to original dtype
        # PyTorch: .to(dtype=x.dtype)
        final_cos = cast(cos_scaled, x_dtype)
        final_sin = cast(sin_scaled, x_dtype)

        return final_cos, final_sin


class LlamaMLP(Module):

    def __init__(self, config, dtype=None, mapping=Mapping()):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = ColumnLinear(
            in_features=self.hidden_size,
            out_features=self.intermediate_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        self.up_proj = ColumnLinear(
            in_features=self.hidden_size,
            out_features=self.intermediate_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        self.down_proj = ColumnLinear(
            in_features=self.intermediate_size,
            out_features=self.hidden_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # PyTorch: self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        gated_x = self.act_fn(self.gate_proj(x))
        up_x = self.up_proj(x)
        fused_x = mul(gated_x, up_x)

        down_proj = self.down_proj(fused_x)
        return down_proj


class LlamaAttention(Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 config: HydraConfig,
                 layer_idx: int,
                 mapping=Mapping(),
                 dtype=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim",
            config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = True

        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.tp_size = mapping.tp_size
        self.hidden_size = config.hidden_size

        self.q_proj = ColumnLinear(
            in_features=config.hidden_size,
            out_features=config.num_attention_heads * self.head_dim,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )
        self.k_proj = ColumnLinear(
            in_features=config.hidden_size,
            out_features=config.num_key_value_heads * self.head_dim,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )
        self.v_proj = ColumnLinear(
            in_features=config.hidden_size,
            out_features=config.num_key_value_heads * self.head_dim,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )
        self.o_proj = ColumnLinear(
            in_features=config.num_attention_heads * self.head_dim,
            out_features=config.hidden_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask=None,
        past_key_value=None,
        cache_position=None,
    ):
        b, s = shape(hidden_states, 0), shape(hidden_states, 1)

        # 1. Q, K, V Projections
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # 2. Reshape and Transpose to [batch, num_heads, seq_len, head_dim]
        # PyTorch: .view(hidden_shape).transpose(1, 2)
        query_states = transpose(
            view(
                q,
                [0, 0, self.num_attention_heads // self.tp_size, self.head_dim
                 ]), 1, 2)
        key_states = transpose(
            view(
                k,
                [0, 0, self.num_key_value_heads // self.tp_size, self.head_dim
                 ]), 1, 2)
        value_states = transpose(
            view(
                v,
                [0, 0, self.num_key_value_heads // self.tp_size, self.head_dim
                 ]), 1, 2)

        # 3. Apply Rotary Position Embedding
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states,
                                                        key_states, cos, sin)
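        # 4. KV-cache update, kept commented out from the HF version:
        #    this port does not maintain a KV cache for the prefix layer.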
        # if past_key_value is not None:
        #     # sin and cos are specific to RoPE models; cache_position needed for the static cache
        #     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # 5. Attention Computation
        attn_output = eager_attention_forward(
            query=query_states,
            key=key_states,
            value=value_states,
            num_key_value_groups=self.num_key_value_groups,
            scaling=self.scaling,
            attention_mask=attention_mask,
        )

        # 6. Final Reshape and Projection
        # PyTorch: attn_output.reshape(*input_shape, -1).contiguous()

        attn_output = view(attn_output, [0, 0, -1])
        attn_output = self.o_proj(attn_output)

        return attn_output


class LlamaDecoderLayer(Module):

    def __init__(self, config: HydraConfig, layer_idx: int, mapping=Mapping()):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LlamaAttention(config=config,
                                        layer_idx=layer_idx,
                                        mapping=mapping)

        self.mlp = LlamaMLP(config, mapping=mapping)
        self.input_layernorm = RmsNorm(
            normalized_shape=config.hidden_size,
            dtype=config.dtype,
        )
        self.post_attention_layernorm = RmsNorm(
            normalized_shape=config.hidden_size,
            dtype=config.dtype,
        )

    def forward(
            self,
            hidden_states,
            attention_mask=None,
            position_ids=None,
            past_key_value=None,
            output_attentions: Optional[bool] = False,
            use_cache: Optional[bool] = False,
            cache_position=None,
            position_embeddings=None,  # necessary, but kept here for BC
    ):
        residual = hidden_states
        normed_hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        attn_output = self.self_attn(
            hidden_states=normed_hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
        )
        hidden_states = add(residual, attn_output)

        # Fully Connected
        residual = hidden_states
        normed_hidden_states = self.post_attention_layernorm(hidden_states)

        mlp_output = self.mlp(normed_hidden_states)
        hidden_states = add(residual, mlp_output)

        return hidden_states


# refer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class PrefixEmbeddingLayer(Module):

    def __init__(self, config: HydraConfig, mapping=Mapping()):
        super().__init__()
        self.vocab_size = config.vocab_size

        self.layer = LlamaDecoderLayer(config=config,
                                       layer_idx=0,
                                       mapping=mapping)

        self.norm = RmsNorm(
            normalized_shape=config.hidden_size,
            dtype=config.dtype,
        )
        self.rotary_emb = LlamaRotaryEmbedding(config=config, mapping=mapping)

    def forward(
        self,
        inputs_embeds,
        position_ids,
        attention_mask=None,
    ):
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        hidden_states = self.layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=False,
            position_embeddings=position_embeddings,
        )

        hidden_states = self.norm(hidden_states)

        return hidden_states


# HydraForCausalLM is a thin wrapper that picks parent class for GenericHydraForCausalLM.
# All hydra functionality is defined in GenericHydraForCausalLM.
class HydraForCausalLM(PretrainedModel):
    config_class = HydraConfig

    def __init__(self, config: HydraConfig):
        super().__init__(config)

        BaseLM = QWenForCausalLM if hasattr(
            config,
            "model_type") and "qwen" in config.model_type else LLaMAForCausalLM

        class GenericHydraForCausalLM(BaseLM):

            def __init__(self, config: HydraConfig):
                super().__init__(config)
                self.num_hydra_heads = config.num_hydra_heads
                self.num_hydra_layers = config.num_hydra_layers
                self.hidden_size = config.hidden_size
                self.vocab_size = config.vocab_size
                vocab_size_padded = pad_vocab_size(self.vocab_size,
                                                   config.mapping.tp_size)

                base_kwargs = config.to_dict()
                prefix_config = BaseLM.config_class(**base_kwargs)
                self.prefix_embedding_layer = PrefixEmbeddingLayer(
                    prefix_config)

                self.hydra_heads = ModuleList([
                    HydraPrefixMLP(num_layers=self.num_hydra_layers - 1,
                                   hidden_size=config.hidden_size,
                                   vocab_size=vocab_size_padded,
                                   hydra_head_idx=i,
                                   hidden_act=config.hidden_act,
                                   dtype=config.dtype,
                                   mapping=config.mapping)
                    for i in range(self.num_hydra_heads)
                ])

                self.input_embed_fn = self.transformer.vocab_embedding
                self.max_hydra_token_len = config.max_draft_len

            def forward(self, *args, **kwargs):
                output_original = True
                hidden_states = super().forward(*args, **kwargs)

                if kwargs['use_cache']:
                    if default_net().plugin_config.paged_kv_cache:
                        lm_logits, hidden_states, _ = hidden_states
                    else:
                        lm_logits, presents, hidden_states = hidden_states

                if self.mapping.is_last_pp_rank():
                    position_ids = kwargs["position_ids"]

                    hidden_states_3d = unsqueeze(
                        hidden_states, 1)  # Shape: [B, H] -> [B, 1, H]

                    prefix_embedding = self.prefix_embedding_layer(
                        inputs_embeds=hidden_states_3d,
                        position_ids=position_ids,
                        attention_mask=None,
                    )

                    _, topk_ids = topk(lm_logits, k=1, dim=-1)
                    next_embedding = self.input_embed_fn(squeeze(topk_ids, -1))

                    # TODO: Need to convert back into for-loop
                    # prefix_embedding is already 3D: [batch, 1, hidden_size]; next_embedding is 2D: [batch, hidden_size]
                    # prefix_embedding_3d = unsqueeze(prefix_embedding, 1) # -> [batch, 1, hidden_size]
                    next_embedding_3d = unsqueeze(
                        next_embedding, 1)  # -> [batch, 1, hidden_size]

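                    # Head i consumes [batch, 1, (i + 2) * hidden_size]: the previous head's input
                    # plus the embedding of the token that the previous head picked greedily.
                    all_head_logits = []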

                    # --- Head 0 ---
                    head_0_input = concat([prefix_embedding, next_embedding_3d],
                                          dim=2)
                    # head_0_input = concat([next_embedding_3d, next_embedding_3d], dim=2)
                    head_0_logits = self.hydra_heads[0](head_0_input)
                    all_head_logits.append(squeeze(head_0_logits, dim=1))

                    # --- Head 1 ---
                    _, next_token_ids_1 = topk(head_0_logits, k=1, dim=-1)
                    next_embedding_1 = self.input_embed_fn(
                        squeeze(next_token_ids_1, -1))

                    head_1_input = concat([head_0_input, next_embedding_1],
                                          dim=2)
                    head_1_logits = self.hydra_heads[1](head_1_input)
                    all_head_logits.append(squeeze(head_1_logits, dim=1))

                    # --- Head 2 ---
                    _, next_token_ids_2 = topk(head_1_logits, k=1, dim=-1)
                    next_embedding_2 = self.input_embed_fn(
                        squeeze(next_token_ids_2, -1))

                    head_2_input = concat([head_1_input, next_embedding_2],
                                          dim=2)
                    head_2_logits = self.hydra_heads[2](head_2_input)
                    all_head_logits.append(squeeze(head_2_logits, dim=1))

                    # --- Head 3 ---
                    _, next_token_ids_3 = topk(head_2_logits, k=1, dim=-1)
                    next_embedding_3 = self.input_embed_fn(
                        squeeze(next_token_ids_3, -1))

                    head_3_input = concat([head_2_input, next_embedding_3],
                                          dim=2)
                    head_3_logits = self.hydra_heads[3](head_3_input)
                    all_head_logits.append(squeeze(head_3_logits, dim=1))

                    medusa_logits = stack(all_head_logits, dim=0)
                    medusa_logits.mark_output('medusa_logits',
                                              self.config.logits_dtype)

                else:
                    hidden_states.mark_output('hidden_states_output',
                                              self.config.dtype)

                if kwargs['use_cache'] and default_net(
                ).plugin_config.paged_kv_cache == False:
                    if self.mapping.is_last_pp_rank():
                        if output_original:
                            return (medusa_logits, lm_logits, presents)
                        return (medusa_logits, presents)
                    return (hidden_states, presents)
                else:
                    if self.mapping.is_last_pp_rank():
                        if output_original:
                            return medusa_logits, lm_logits
                        return medusa_logits
                    return hidden_states

            def prepare_inputs(self, *args, **kwargs):
                kwargs['speculative_decoding_draft_tokens_external'] = False
                kwargs['max_draft_len'] = self.max_hydra_token_len
                return super().prepare_inputs(*args, **kwargs)

        self.model = GenericHydraForCausalLM(config)

    # Specialization to redirect accesses to self.model
    def __getattribute__(self, name):
        if name == 'model' or '__' in name:
            return object.__getattribute__(self, name)
        else:
            model = object.__getattribute__(self, 'model')
            return model.__getattribute__(name)

    # Override specialized __setattr__ defined in Module
    def __setattr__(self, name, value) -> None:
        object.__setattr__(self, name, value)

    @classmethod
    def from_hugging_face(
            cls,
            hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
            dtype: str = 'auto',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            **kwargs):
        import transformers

        assert hf_model_or_dir is not None
        speculative_model_dir = kwargs.get('speculative_model', None)

        use_preloading = isinstance(hf_model_or_dir,
                                    transformers.PreTrainedModel)
        if use_preloading:
            hf_model = hf_model_or_dir
            hf_config_or_dir = hf_model.config
        else:
            hf_model_dir = hf_model_or_dir
            hf_config_or_dir = hf_model_or_dir

        config = HydraConfig.from_hugging_face(hf_config_or_dir,
                                               dtype=dtype,
                                               mapping=mapping,
                                               quant_config=quant_config,
                                               **kwargs)

        # ModelOpt ckpt has combined base model and Hydra-head
        is_modelopt_ckpt = True if not speculative_model_dir else False

        if not use_preloading:
            trust_remote_code = kwargs.pop('trust_remote_code', True)

            if is_modelopt_ckpt:
                hf_model = LLaMAForCausalLM.from_hugging_face(
                    hf_model_dir,
                    dtype,
                    mapping=mapping,
                    quant_config=quant_config,
                    **kwargs)
            else:
                hf_model = AutoModelForCausalLM.from_pretrained(
                    hf_model_dir,
                    torch_dtype="auto",
                    trust_remote_code=trust_remote_code)

                assert isinstance(hf_model, transformers.PreTrainedModel)

        if is_modelopt_ckpt:
            weights = {
                name: numpy_to_torch(param.raw_value)
                for name, param in hf_model.named_parameters()
            }
        else:
            weights = convert_hf_llama(
                hf_model,
                config.mapping,
                dtype='float16',
                use_parallel_embedding=config.use_parallel_embedding)

        model = cls(config)

        if is_modelopt_ckpt:
            num_hydra_heads = config.config.num_hydra_heads
            num_hydra_layers = config.config.num_hydra_layers
            speculative_model_dir = hf_model_or_dir
        else:
            config_file = speculative_model_dir / "config.json"
            with open(config_file) as fp:
                model_config = json.load(fp)

            num_hydra_heads = kwargs[
                'speculative_config'].num_hydra_heads if 'speculative_config' in kwargs else model_config.get(
                    'hydra_num_heads', None)
            num_hydra_layers = model_config.get('hydra_num_layers', None)
        hydra_weights = load_hydra_hf(hydra_path=speculative_model_dir,
                                      num_hydra_heads=num_hydra_heads,
                                      num_hydra_layers=num_hydra_layers,
                                      mapping=mapping,
                                      dtype="float16",
                                      base_config=hf_model.config,
                                      is_modelopt_ckpt=is_modelopt_ckpt)
        weights.update(hydra_weights)
        model.load(weights)
        return model

As you can see, most of the work goes into the PrefixEmbeddingLayer. Thankfully, this custom model no longer triggers illegal memory access errors.
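The four unrolled "Head" blocks in forward() all follow one pattern, which is what the TODO about converting back into a for-loop refers to. Here is a minimal pure-Python sketch of that pattern, assuming greedy (top-1) drafting and plain lists as vectors; `embed` and `heads` are hypothetical callables standing in for `input_embed_fn` and `HydraPrefixMLP`. It returns draft token ids instead of logits and is an illustration, not TensorRT-LLM code.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def argmax(xs: Sequence[float]) -> int:
    # Greedy choice, equivalent in spirit to topk(..., k=1) in forward().
    return max(range(len(xs)), key=lambda i: xs[i])


def hydra_greedy_draft(
    prefix_embedding: Vector,
    lm_logits: Vector,
    embed: Callable[[int], Vector],
    heads: Sequence[Callable[[Vector], Vector]],
) -> List[int]:
    """Return one greedy draft token per head (hypothetical helper).

    Head 0 sees [prefix_embedding, embed(base-model token)]; every later head
    sees the previous head's input plus the embedding of the token that the
    previous head just predicted, mirroring the unrolled forward().
    """
    # List concatenation plays the role of concat(..., dim=2).
    head_input = prefix_embedding + embed(argmax(lm_logits))
    draft: List[int] = []
    for i, head in enumerate(heads):
        if i > 0:
            # Sequential dependency: this needs the token from head i - 1.
            head_input = head_input + embed(draft[-1])
        logits = head(head_input)
        draft.append(argmax(logits))
    return draft
```

Because each step appends one more embedding, head i sees an input of width (i + 2) * hidden_size, and it cannot start until head i - 1 has produced its token.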

I’m also including the source code that reads Hydra Head weights:

def load_hydra_hf(hydra_path: str,
                  num_hydra_heads: int,
                  num_hydra_layers: int,
                  base_config: PretrainedConfig,
                  mapping=Mapping(),
                  dtype='float32',
                  use_weight_only=False,
                  plugin_weight_only_quant_type=None,
                  is_modelopt_ckpt=False):
    if is_modelopt_ckpt:
        from safetensors.torch import load_file
        state_dict = {}
        for filename in sorted(Path(hydra_path).glob("*.safetensors")):
            print(f"Loading the weights of Hydra heads from {filename}")
            state_dict.update(load_file(filename))
    else:
        is_ckpt_safetensors = False

        ckpt_file = Path(hydra_path) / "hydra_lm_head.pt"
        if not ckpt_file.exists():
            ckpt_file = Path(hydra_path) / "hydra_lm_head.safetensors"
            is_ckpt_safetensors = True

        if is_ckpt_safetensors:
            logger.info("Safetensors Found ...")
            from safetensors.torch import load_file
            state_dict = load_file(ckpt_file)
        else:
            state_dict = torch.load(ckpt_file, map_location="cpu")

    torch_dtype = str_dtype_to_torch(dtype)
    weights = {}

    # Embedding
    # embedding_weight = state_dict["prefix_embeding_layer.embed_tokens.weight"].clone().to(torch_dtype)
    # split_emb = split(embedding_weight, mapping.tp_size, mapping.tp_rank, dim=0)
    # weights["prefix_embedding_layer.layer.vocab_embedding.weight"] = split_emb

    # Attention (QKV, O): Q/K/V are split on dim 0 (column-parallel), O on dim 1 (row-parallel)
    q_w = state_dict[
        f"prefix_embeding_layer.layers.0.self_attn.q_proj.weight"].clone().to(
            torch_dtype)
    k_w = state_dict[
        f"prefix_embeding_layer.layers.0.self_attn.k_proj.weight"].clone().to(
            torch_dtype)
    v_w = state_dict[
        f"prefix_embeding_layer.layers.0.self_attn.v_proj.weight"].clone().to(
            torch_dtype)
    o_w = state_dict[
        f"prefix_embeding_layer.layers.0.self_attn.o_proj.weight"].clone().to(
            torch_dtype)

    weights[
        f"prefix_embedding_layer.layer.self_attn.q_proj.weight"] = split_matrix_tp(
            q_w,
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.self_attn.k_proj.weight"] = split_matrix_tp(
            k_w,
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.self_attn.v_proj.weight"] = split_matrix_tp(
            v_w,
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.self_attn.o_proj.weight"] = split_matrix_tp(
            o_w,
            mapping.tp_size,
            mapping.tp_rank,
            dim=1,
        )

    # MLP (gate_proj, up_proj, down_proj): gate/up are split on dim 0, down on dim 1
    weights[
        f"prefix_embedding_layer.layer.mlp.gate_proj.weight"] = split_matrix_tp(
            state_dict[f"prefix_embeding_layer.layers.0.mlp.gate_proj.weight"].
            clone().to(torch_dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.mlp.up_proj.weight"] = split_matrix_tp(
            state_dict[f"prefix_embeding_layer.layers.0.mlp.up_proj.weight"].
            clone().to(torch_dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.mlp.down_proj.weight"] = split_matrix_tp(
            state_dict[f"prefix_embeding_layer.layers.0.mlp.down_proj.weight"].
            clone().to(torch_dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=1,
        )

    # LayerNorm (no need to split)
    weights[
        f"prefix_embedding_layer.layer.input_layernorm.weight"] = state_dict[
            f"prefix_embeding_layer.layers.0.input_layernorm.weight"].clone(
            ).to(torch_dtype)
    weights[f"prefix_embedding_layer.layer.post_attention_layernorm.weight"] = state_dict[
        f"prefix_embeding_layer.layers.0.post_attention_layernorm.weight"].clone(
        ).to(torch_dtype)
    weights[f"prefix_embedding_layer.norm.weight"] = state_dict[
        f"prefix_embeding_layer.norm.weight"].clone().to(torch_dtype)

    # Load Hydra heads weights
    for i in range(num_hydra_heads):
        w = state_dict[f"hydra_mlp.{i}.1.linear.weight"].clone().to(torch_dtype)
        weights[f"hydra_heads.{i}.hydra_mlp.linear.weight"] = split(
            w, mapping.tp_size, mapping.tp_rank, dim=0)
        weights[f"hydra_heads.{i}.hydra_mlp.linear.bias"] = state_dict[
            f"hydra_mlp.{i}.1.linear.bias"].clone().to(torch_dtype)

        # res_connection weights
        w_res = state_dict[f"hydra_mlp.{i}.1.res_connection.weight"].clone().to(
            torch_dtype)
        weights[f"hydra_heads.{i}.hydra_mlp.res_connection.weight"] = split(
            w_res, mapping.tp_size, mapping.tp_rank, dim=0)
        weights[f"hydra_heads.{i}.hydra_mlp.res_connection.bias"] = state_dict[
            f"hydra_mlp.{i}.1.res_connection.bias"].clone().to(torch_dtype)

        for l_idx in range(num_hydra_layers - 1):
            seq_idx = 3 + 2 * l_idx  # 3, 5, 7, 9...
            w = state_dict[f"hydra_mlp.{i}.{seq_idx}.linear.weight"].clone().to(
                torch_dtype)
            weights[
                f"hydra_heads.{i}.hydra_mlps.{l_idx}.linear.weight"] = split(
                    w, mapping.tp_size, mapping.tp_rank, dim=0)
            weights[
                f"hydra_heads.{i}.hydra_mlps.{l_idx}.linear.bias"] = state_dict[
                    f"hydra_mlp.{i}.{seq_idx}.linear.bias"].clone().to(
                        torch_dtype)

        # Load lm_head
        w_lm = state_dict[f"hydra_lm_head.{i}.1.weight"].clone().to(torch_dtype)
        weights[f"hydra_heads.{i}.hydra_lm_head.weight"] = split(
            w_lm, mapping.tp_size, mapping.tp_rank, dim=0)

        if f"hydra_lm_head.{i}.1.bias" in state_dict:
            weights[f"hydra_heads.{i}.hydra_lm_head.bias"] = state_dict[
                f"hydra_lm_head.{i}.1.bias"].clone().to(torch_dtype)

    return weights
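The loader above mostly does two things: it renames checkpoint keys (for example `hydra_mlp.{i}.{3 + 2 * l_idx}` becomes `hydra_heads.{i}.hydra_mlps.{l_idx}`) and it shards each weight for tensor parallelism, splitting on dim 0 for q/k/v/gate/up and on dim 1 for o/down. The sketch below restates both in pure Python. `split_matrix` and `hydra_head_key_map` are hypothetical helpers written for illustration (not TensorRT-LLM's `split_matrix_tp`), biases are omitted, and the split assumes the sharded dimension divides evenly by `tp_size`.

```python
from typing import Dict, List

Matrix = List[List[float]]


def split_matrix(w: Matrix, tp_size: int, tp_rank: int, dim: int) -> Matrix:
    """Keep this rank's contiguous shard of a 2-D matrix (hypothetical helper).

    dim=0 keeps a block of rows (column-parallel layers such as q/k/v/gate/up);
    dim=1 keeps a block of columns (row-parallel layers such as o/down).
    """
    if dim == 0:
        n = len(w)
        assert n % tp_size == 0, "rows must divide evenly across ranks"
        chunk = n // tp_size
        return [row[:] for row in w[tp_rank * chunk:(tp_rank + 1) * chunk]]
    if dim == 1:
        n = len(w[0])
        assert n % tp_size == 0, "columns must divide evenly across ranks"
        chunk = n // tp_size
        return [row[tp_rank * chunk:(tp_rank + 1) * chunk] for row in w]
    raise ValueError("only 2-D matrices (dim 0 or 1) are supported")


def hydra_head_key_map(head_idx: int, num_hydra_layers: int) -> Dict[str, str]:
    """Checkpoint key -> TRT-LLM key for one head's weights, as in load_hydra_hf."""
    i = head_idx
    mapping = {
        f"hydra_mlp.{i}.1.linear.weight":
            f"hydra_heads.{i}.hydra_mlp.linear.weight",
        f"hydra_mlp.{i}.1.res_connection.weight":
            f"hydra_heads.{i}.hydra_mlp.res_connection.weight",
        f"hydra_lm_head.{i}.1.weight":
            f"hydra_heads.{i}.hydra_lm_head.weight",
    }
    for l_idx in range(num_hydra_layers - 1):
        seq_idx = 3 + 2 * l_idx  # the loader only reads odd positions: 3, 5, 7, ...
        mapping[f"hydra_mlp.{i}.{seq_idx}.linear.weight"] = (
            f"hydra_heads.{i}.hydra_mlps.{l_idx}.linear.weight")
    return mapping
```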

Test Results

My machine specs:

  • RTX 4090 x1 (24GB VRAM)
  • 64GB RAM
  • Intel i7 13th Gen CPU

I used the same prompts (such as “Who are you?” and others) for 10 consecutive runs at batch_size=1 (with --run_profiling enabled in run.py) to measure the average latency.

This is a rough benchmark, but the results show that Hydra, while not faster than Medusa, is indeed faster than native Vicuna-7B!
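A measurement of this kind (repeated runs of fixed prompts at batch_size=1, then averaging) can also be taken outside run.py with a small harness like the sketch below. `generate` is a placeholder for whatever call drives the TensorRT-LLM Python session, the warm-up pass is an extra assumption not mentioned in the setup above, and this is not a re-implementation of `--run_profiling`.

```python
import statistics
import time
from typing import Callable, Sequence, Tuple


def average_latency(
    generate: Callable[[str], object],
    prompts: Sequence[str],
    runs: int = 10,
    warmup: int = 1,
) -> Tuple[float, float]:
    """Return (mean, stdev) of per-request latency in seconds.

    `generate` is a placeholder for one batch_size=1 request to the engine.
    """
    if runs < 1 or not prompts:
        raise ValueError("need at least one run and one prompt")

    # Warm-up calls (an added assumption) are excluded from the statistics.
    for prompt in prompts[:warmup]:
        generate(prompt)

    samples = []
    for _ in range(runs):
        for prompt in prompts:
            start = time.perf_counter()
            generate(prompt)
            samples.append(time.perf_counter() - start)

    mean = statistics.mean(samples)
    stdev = statistics.stdev(samples) if len(samples) > 1 else 0.0
    return mean, stdev
```

Reporting the standard deviation next to the mean makes it easier to tell whether a gap between two decoding modes is larger than run-to-run noise.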

Hydra was shown to outperform Medusa in the paper, so what happened here? I see two possible explanations:

  1. Hydra Heads are sequentially dependent: each Head has to wait for the token produced by the previous Head, whereas Medusa Heads can all run in parallel. This may reduce GPU parallelism.
  2. More likely: my implementation isn’t optimal XD. After all, Medusa was officially optimized by the TensorRT-LLM team, while I merely pieced things together and even rewrote a custom LlamaModel just to avoid illegal memory access errors, possibly sacrificing some graph optimizations.
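Both explanations fit a common back-of-envelope model: the speedup over plain decoding is roughly the average number of tokens emitted per verification step divided by the cost of one speculative step relative to one plain decode step. Under that simplified model (an assumption; it ignores prefill, tree size, and batching effects), Hydra's higher acceptance only pays off if its extra sequential head work keeps the step cost below a break-even point. The functions below are an illustrative sketch; plug in measured values for both methods.

```python
def estimated_speedup(tokens_per_step: float, step_cost_ratio: float) -> float:
    """Rough speedup over plain autoregressive decoding.

    tokens_per_step: average tokens emitted per verification step (>= 1).
    step_cost_ratio: latency of one speculative step divided by the latency
    of one plain decode step; sequential draft heads push this up.
    """
    if tokens_per_step < 1 or step_cost_ratio <= 0:
        raise ValueError("expected tokens_per_step >= 1 and step_cost_ratio > 0")
    return tokens_per_step / step_cost_ratio


def break_even_cost(tokens_hydra: float, tokens_medusa: float,
                    cost_medusa: float) -> float:
    """Largest Hydra step-cost ratio at which Hydra still matches Medusa.

    From tokens_hydra / cost_hydra >= tokens_medusa / cost_medusa.
    """
    return cost_medusa * tokens_hydra / tokens_medusa
```

In this model, a higher acceptance rate (more tokens per step) can be cancelled out entirely if the per-step cost grows by the same factor, which is exactly what explanations 1 and 2 would do.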

Conclusion

Supporting Hydra on TensorRT-LLM took me two full weekends plus every free evening for a week, around 30 to 35 hours in total. It’s probably the most time-consuming side project I’ve ever done.

Still, hacking through an inference acceleration framework and adding a new speculative decoding mode felt very rewarding. While many modules were copied from Medusa, quite a few issues were unique to Hydra.

It’s hard to document every fine-grained detail of the implementation and debugging. More than once I almost gave up and moved on to the next project XD

But now I’m much more familiar with speculative decoding trees on TensorRT-LLM. Maybe I’ll explore other decoding modes or optimize Hydra further, especially in the C++ backend and CUDA kernels.

If you’re interested, check out my implementation here: GitHub Repo

I hope to add documentation soon! Thanks for reading~

