Last Updated on 2025-07-01 by Clay
Introduction
I’ve previously studied many speculative decoding acceleration techniques and implemented several of the architectures in PyTorch, including model definitions, training, and inference scripts (fast-llm-inference). This time, of course, I have a new goal.
A month or two ago I read the Hydra paper; we can consider it a variant of the Medusa architecture. As a small personal side project, I want to support the official Hydra weights on the TensorRT-LLM acceleration framework.
Goal: on TensorRT-LLM, using the Python session, have the Hydra heads generate reasonable draft tokens for the base model to verify. (If you’re interested in my implementation branch: it hasn’t been discussed with the TensorRT-LLM maintainers yet and is simply hosted on my GitHub: support-spec-decode-hydra.)
Hydra Overview
For more details, you can refer to the notes I previously wrote: [Paper Reading] Hydra: Sequentially-Dependent Draft Heads for Medusa Decoding.
In short, Hydra and Medusa are very similar in architecture. Each Head is responsible for decoding outputs at different time steps. However, Hydra introduces dependency on the tokens generated by the previous Heads, which significantly improves the acceptance rate of Hydra Heads.
You can imagine that Medusa predicts without knowing the exact tokens generated by previous Heads — it only sees relatively vague hidden states — while Hydra explicitly tells the Head preparing to generate which token was generated previously (via concatenation of token embeddings).
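The difference can be sketched in a few lines of NumPy. The arrays below are random stand-ins (hidden size 4096 as in Vicuna-7B); the point is only the shape of what each kind of head consumes:

```python
import numpy as np

hidden_size = 4096
rng = np.random.default_rng(0)

# Stand-ins for the base model's last hidden state and the embedding
# of the token just sampled from the base LM head (both made up here).
base_hidden = rng.standard_normal(hidden_size)
prev_token_embed = rng.standard_normal(hidden_size)

# Medusa-style (ungrounded): each head sees only the hidden state.
medusa_head_input = base_hidden

# Hydra-style (grounded): the previously generated token's embedding
# is concatenated onto the hidden state before the head runs.
hydra_head_input = np.concatenate([base_hidden, prev_token_embed])

print(medusa_head_input.shape)  # (4096,)
print(hydra_head_input.shape)   # (8192,)
```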
By reading the official implementation source code (prefix_mlp_head.py), we can see the following:
self.hydra_head = HydraMLP(
hydra_num_layers=self.hydra_num_layers,
hydra_num_heads=self.hydra,
grounded_heads=self.grounded_heads,
input_embed_fn=self.base_model.model.embed_tokens,
base_config=self.config,
lm_head_init_weight=base_model.lm_head.weight.data
)
self.hydra_lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
In addition to specifying the number of layers and heads, we also need to provide input_embed_fn, which is usually the model’s embedding layer.
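Conceptually, input_embed_fn is nothing more than a row lookup into the embedding table. A minimal NumPy sketch (the table size and ids below are made up for illustration):

```python
import numpy as np

vocab_size, hidden_size = 32000, 4096
rng = np.random.default_rng(0)
embed_tokens = rng.standard_normal((vocab_size, hidden_size))

# input_embed_fn is just a row lookup into the embedding table:
def input_embed_fn(input_ids):
    return embed_tokens[input_ids]

ids = np.array([[1, 5, 7]])        # [batch, seq_len]
print(input_embed_fn(ids).shape)   # (1, 3, 4096)
```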
In the forward() method of HydraMLP, we can clearly see how the Hydra Heads operate:
def forward(self, base_hidden_states, input_ids=None, noise=None):
"""
Forward pass of the MLP.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output after the MLP.
"""
hydra_hidden_states = []
if self.grounded_heads:
assert input_ids is not None, "Input ids must be provided for grounded heads"
with torch.inference_mode():
input_embeds = self.input_embed_fn(input_ids)
if noise is not None:
input_embeds = input_embeds + noise
hydra_inputs = [base_hidden_states]
for i in range(self.hydra_num_heads):
# Move input embeddings back one spot for each hydra head idx
hydra_inputs.append(torch.roll(input_embeds, shifts=-(i+1), dims=1))
for i in range(self.hydra_num_heads):
head_input = torch.cat(hydra_inputs[:i + 2], dim=-1)
hydra_hidden_states.append(self.hydra_mlp[i](head_input))
else:
for i in range(self.hydra_num_heads):
hydra_hidden_states.append(self.hydra_mlp[i](base_hidden_states))
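The torch.roll call above is what lines the embeddings up per head at training time: shifting the embedding sequence back by (i + 1) positions places, at each position, the token that head i must condition on. A NumPy equivalent of the shift for head 0 (toy token ids, purely illustrative):

```python
import numpy as np

# Per-head shift used in HydraMLP.forward: head i sees the token
# sequence rolled back by (i + 1) positions along the sequence axis.
seq = np.array([[10, 11, 12, 13]])   # token ids, [batch, seq_len]
shift_for_head_0 = np.roll(seq, shift=-1, axis=1)
print(shift_for_head_0.tolist())     # [[11, 12, 13, 10]]
```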
We can also clearly see how tokens are transformed into embeddings:
with torch.inference_mode():
input_embeds = self.input_embed_fn(input_ids)
According to experiments by the research team, this is 1.1x faster than Medusa.
Now, I aim to support it on TensorRT-LLM and run it successfully.
Implementation Concepts & Process
So, to actually make it work on TensorRT-LLM, what needs to be done?
First, I defined my goal: since I want to support the official Hydra model, I downloaded the base Vicuna-7B model along with the official Hydra weights, ankner--hydra-vicuna-7b-v1.3.
After downloading, I read the config.json, which clearly indicated this is a “prefix-mlp” variant, and this design caused me a lot of trouble later on XD
The original Medusa Heads are parallel. The model’s hidden states can be directly passed to each Head for inference. Hydra, however, is different: it first enters a single-layer Llama Model, then concatenates the most probable token’s embedding with the model’s hidden state before passing it to the Hydra Heads.
Additionally, Hydra Heads are not parallel — they are sequential. Token 1 generated by Head 1 is transformed into Embedding 1 and concatenated with earlier features to be passed into Head 2, and so on.
As a result, the input dimensions to the Hydra Heads expand as: 8192, 12288, 16384, 20480… and so on. I’ll explain this challenge in more detail later.
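The growth is easy to see numerically. With Vicuna-7B’s hidden size of 4096 and four heads, head i concatenates the prefix hidden state with one token embedding per already-decided token, so its input width is hidden_size * (i + 2):

```python
hidden_size = 4096
num_hydra_heads = 4

# Head i consumes the prefix hidden state plus (i + 1) token embeddings
# (the base model's sampled token and one per preceding head), so the
# concatenated input width grows by hidden_size at every step.
input_dims = [hidden_size * (i + 2) for i in range(num_hydra_heads)]
print(input_dims)  # [8192, 12288, 16384, 20480]
```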
Next, you should understand that TensorRT-LLM has both Python and C++ backends. My goal this time is to implement the Python backend. Also, TensorRT-LLM already supports Medusa — since Hydra is very similar, I mainly mimicked Medusa to redefine the architecture.
So here’s what I did:
- Created a new hydra directory under tensorrt_llm/models/ (copied from medusa/)
- Redefined model.py, weight.py, and config.py (GitHub link)
- To verify the implementation, created examples/hydra/ and wrote convert.sh, build.sh, and run.sh scripts to confirm the model can run
- Modified examples/run.py and tensorrt_llm/runtime/generation.py to add Hydra support, mostly copying the Medusa flow but updating paths and flags to use --speculative_decoding_mode hydra
examples/hydra/convert.sh, build.sh, run.sh
First, the script to execute:
#!/bin/bash
python convert_checkpoint.py \
--model_dir ./lmsys--vicuna-7b-v1.3/ \
--hydra_model_dir ./ankner--hydra-vicuna-7b-v1.3 \
--output_dir ./tllm_checkpoint_1gpu_hydra \
--dtype float16 \
--num_hydra_heads 4 \
--num_hydra_layers 4
This script converts a pretrained large language model (e.g., LLaMA, Qwen2) into a highly optimized TensorRT-LLM checkpoint format, allowing it to be built into a fast inference engine.
#!/bin/bash
trtllm-build \
--checkpoint_dir tllm_checkpoint_1gpu_hydra \
--output_dir ./tmp/hydra/7B/trt_engine/fp16/1-gpu/ \
--gemm_plugin float16 \
--speculative_decoding_mode hydra \
--max_batch_size 4
This is the key script to build a Hydra-based LLM with TensorRT-LLM. It reads the checkpoint generated earlier and builds an optimized engine.
If the model definition fails, or if hydra/weight.py encounters issues reading weights, it will error out here. Incorrect graph shapes will also trigger build errors.
#!/bin/bash
python ../run.py \
--engine_dir ./tmp/hydra/7B/trt_engine/fp16/1-gpu/ \
--use_py_session \
--tokenizer_dir ./lmsys--vicuna-7b-v1.3/ \
--max_output_len=100 \
--hydra_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
--temperature 1.0 \
--input_text "Once upon" \
--debug_mode
Now it becomes simple: provide your input_text and test Hydra’s generation results.
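The hydra_choices argument uses the same format as Medusa’s tree choices: each inner list is one path in the draft tree, where element j selects one of the top-k candidates proposed at depth j. As a quick sanity check, we can count how many tree nodes sit at each depth for a small prefix of the tree above:

```python
from collections import Counter

# A small prefix of the Medusa-style tree passed via --hydra_choices;
# each path is a node in the draft tree, indexed by which top-k
# candidate is taken at each head.
hydra_choices = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2]]

nodes_per_depth = Counter(len(path) for path in hydra_choices)
print(dict(nodes_per_depth))   # {1: 3, 2: 4, 3: 1}
```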
Problems Encountered
There were many pitfalls during implementation. Below are a few particularly painful ones that initially seemed unfixable.
In retrospect, the main issues stemmed from:
- Couldn’t retrieve Hydra logits: the acceptance rate was 0 at first. It turned out I had to explicitly call medusa_logits.mark_output('medusa_logits', self.config.logits_dtype) in model.py. Changing the tensor name also caused failures, so it is probably hardcoded to medusa_logits.
- CUDA illegal memory access: I debugged this for a long time and eventually found that the prefix_embedding_layer was the root cause. Since it bypasses the embedding layer (hidden states are passed in directly), calling TensorRT-LLM’s internal Llama model triggered CUDA errors. In the end, I ported Hugging Face’s LlamaModel and rewrote all of its internal operations with TensorRT-LLM modules to preserve graph optimization.
- AssertionError: tensor /concat_L2636/CONCATENATION_1_output_0 has an invalid shape: TensorRT failed to infer the dynamic shapes caused by the ever-growing Hydra input dimensions. Eventually I hardcoded the loop, which is a bad solution; I plan to refactor it with dynamic head support.
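For reference, the sequential drafting that the hardcoded per-head blocks implement can be written as a plain loop. The NumPy sketch below uses toy sizes and random weights; everything here is a stand-in for illustration, not the TensorRT-LLM graph code (which cannot express a loop with growing shapes this easily):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, vocab_size, num_heads = 8, 16, 4

embed = rng.standard_normal((vocab_size, hidden_size))   # embedding table
head_weights = [rng.standard_normal((hidden_size * (i + 2), vocab_size))
                for i in range(num_heads)]               # one MLP stand-in per head

prefix_hidden = rng.standard_normal(hidden_size)         # prefix layer output
base_token = 3                                           # token sampled by the base LM

# Sequential (grounded) drafting: each head's argmax token is embedded
# and appended to the input of the next head.
head_input = np.concatenate([prefix_hidden, embed[base_token]])
draft_tokens = []
for w in head_weights:
    logits = head_input @ w
    next_token = int(np.argmax(logits))
    draft_tokens.append(next_token)
    head_input = np.concatenate([head_input, embed[next_token]])

print(len(draft_tokens))   # 4
```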
So my final model.py looks like this:
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from typing import Optional, Union
import numpy as np
import torch
from transformers import AutoModelForCausalLM
from tensorrt_llm._utils import numpy_to_torch
from tensorrt_llm.models.hydra.weight import load_hydra_hf
from tensorrt_llm.models.llama.model import LLaMAForCausalLM, RmsNorm
from tensorrt_llm.models.qwen.model import QWenForCausalLM
from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import (ACT2FN, add, cast, concat, constant, cos, div,
expand, matmul, mul, shape, sin, slice, softmax,
squeeze, stack, topk, transpose, unsqueeze, view)
from ...layers import ColumnLinear
from ...mapping import Mapping
from ...module import Module, ModuleList
from ..modeling_utils import PretrainedModel, QuantConfig
from .config import HydraConfig
from .weight import convert_hf_llama
# refer: https://github.com/zankner/Hydra/blob/main/hydra/model/hydra_heads/prefix_mlp_head.py#L44
class HydraResBlock(Module):
def __init__(
self,
hidden_size,
hidden_act="silu",
num_condition=0,
dtype=None,
mapping=Mapping(),
):
super().__init__()
input_size = hidden_size * (num_condition + 1)
self.linear = ColumnLinear(input_size,
hidden_size,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True)
self.res_connection = ColumnLinear(
input_size,
hidden_size,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True) if num_condition > 0 else torch.nn.Identity()
self.hidden_act = hidden_act
def forward(self, x):
return self.res_connection(x) + ACT2FN[self.hidden_act](self.linear(x))
class HydraPrefixMLP(Module):
def __init__(
self,
num_layers,
hidden_size,
vocab_size,
hydra_head_idx,
hidden_act="silu",
dtype=None,
mapping=Mapping(),
lm_head_init_weight=None,
):
super().__init__()
self.hydra_mlp = HydraResBlock(hidden_size=hidden_size,
num_condition=hydra_head_idx + 1,
hidden_act=hidden_act,
dtype=dtype,
mapping=mapping)
self.hydra_mlps = ModuleList([
HydraResBlock(hidden_size=hidden_size,
hidden_act=hidden_act,
dtype=dtype,
mapping=mapping) for _ in range(num_layers)
])
self.hydra_lm_head = ColumnLinear(hidden_size,
vocab_size,
bias=True,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True)
def forward(self, x):
hidden_states = self.hydra_mlp(x)
for layer in self.hydra_mlps:
hidden_states = layer(hidden_states)
return self.hydra_lm_head(hidden_states)
def _compute_default_rope_parameters(
config: Optional[HydraConfig] = None,
**rope_kwargs,
):
# if len(rope_kwargs) > 0:
# base = rope_kwargs["base"]
# dim = rope_kwargs["dim"]
# elif config is not None:
# base = config.rope_theta
# partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
# head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
# dim = int(head_dim * partial_rotary_factor)
base = getattr(config, "rope_theta", 10000.0)
head_dim = config.hidden_size // config.num_attention_heads
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
dim = int(head_dim * partial_rotary_factor)
attention_factor = 1.0 # Unused in this type of RoPE
# Compute the inverse frequencies
idx = np.arange(0, dim, 2, dtype=np.float32)
inv_freq = 1.0 / (base**(idx / dim))
return inv_freq, attention_factor
def _compute_llama3_parameters(
config: HydraConfig,
**rope_kwargs,
):
# Gets the default RoPE parameters
inv_freq, attention_factor = _compute_default_rope_parameters(
config, **rope_kwargs)
factor = 8.0
low_freq_factor = 1.0
high_freq_factor = 4.0
old_context_len = 8192.0
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / inv_freq
# Use numpy
inv_freq_llama = np.where(wavelen > low_freq_wavelen, inv_freq / factor,
inv_freq)
smooth_factor = (old_context_len / wavelen -
low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (
1 - smooth_factor
) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen
>= high_freq_wavelen)
inv_freq_llama = np.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
return inv_freq_llama, attention_factor
def rotate_half(x):
"""
TensorRT-LLM functional version of rotate_half.
Assumes x is a 4D tensor: [batch, num_heads, seq_len, head_dim]
Splits the last dimension in half, rotates the halves, and concatenates them.
"""
# Get dimensions as scalar tensors
dim0 = squeeze(shape(x, 0), dim=0)
dim1 = squeeze(shape(x, 1), dim=0)
dim2 = squeeze(shape(x, 2), dim=0)
last_dim = squeeze(shape(x, 3), dim=0)
# Compute half of last_dim
two = constant(np.array([2], dtype="int64"))
half_dim = squeeze(div(last_dim, two), dim=0)
# Create scalar zero for use in starts
zero = constant(np.array(0, dtype="int64"))
# Define starts and sizes for slicing
starts1 = stack([zero, zero, zero, zero], dim=0)
sizes1 = stack([dim0, dim1, dim2, half_dim], dim=0)
starts2 = stack([zero, zero, zero, half_dim], dim=0)
# Slice tensors into two halves along the last dimension
x1 = slice(x, starts=starts1, sizes=sizes1)
x2 = slice(x, starts=starts2, sizes=sizes1)
# Negate the second half and concatenate
neg_x2 = mul(x2, -1.0)
return concat([neg_x2, x1], dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""
TensorRT-LLM functional version of apply_rotary_pos_emb.
"""
# PyTorch: cos.unsqueeze(unsqueeze_dim)
cos_expanded = unsqueeze(cos, unsqueeze_dim)
sin_expanded = unsqueeze(sin, unsqueeze_dim)
# PyTorch: (q * cos) + (rotate_half(q) * sin)
rotated_q = rotate_half(q)
q_embed = add(mul(q, cos_expanded), mul(rotated_q, sin_expanded))
# PyTorch: (k * cos) + (rotate_half(k) * sin)
rotated_k = rotate_half(k)
k_embed = add(mul(k, cos_expanded), mul(rotated_k, sin_expanded))
return q_embed, k_embed
def repeat_kv(hidden_states, n_rep: int):
"""
TensorRT-LLM functional version of repeat_kv
"""
if n_rep == 1:
return hidden_states
batch, num_key_value_heads, slen, head_dim = (shape(hidden_states, 0),
shape(hidden_states, 1),
shape(hidden_states, 2),
shape(hidden_states, 3))
hidden_states_unsqueezed = unsqueeze(hidden_states, 2)
hidden_states_expanded = expand(
hidden_states_unsqueezed,
[batch, num_key_value_heads, n_rep, slen, head_dim])
final_shape = [batch, num_key_value_heads * n_rep, slen, head_dim]
return view(hidden_states_expanded, final_shape)
def eager_attention_forward(
query,
key,
value,
num_key_value_groups: int,
scaling: float,
dropout: float = 0.0,
attention_mask=None,
**kwargs,
):
key_states = repeat_kv(key, num_key_value_groups)
value_states = repeat_kv(value, num_key_value_groups)
# Attention Scores: (Q @ K.T) * scaling
key_states_T = transpose(key_states, 2, 3)
attn_scores = matmul(query, key_states_T)
attn_scores_scaled = mul(attn_scores, scaling)
if attention_mask is not None:
key_len = shape(key_states, 2)
mask_shape = shape(attention_mask)
causal_mask = slice(
attention_mask,
starts=[0, 0, 0, 0],
sizes=[mask_shape[0], mask_shape[1], mask_shape[2], key_len])
attn_scores_masked = add(attn_scores_scaled, causal_mask)
else:
attn_scores_masked = attn_scores_scaled
# Softmax
query_dtype = query.dtype
attn_weights_fp32 = softmax(cast(attn_scores_masked, "float32"), dim=-1)
attn_weights = cast(attn_weights_fp32, query_dtype)
# Ignore dropout
# Compute Attention Output: attn_weights @ V
attn_output = matmul(attn_weights, value_states)
# Transpose
attn_output = transpose(attn_output, 1, 2)
return attn_output
class LlamaRotaryEmbedding(Module):
def __init__(self, config: HydraConfig, mapping=Mapping()):
super().__init__()
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = _compute_llama3_parameters
self.original_inv_freq, self.attention_scaling = self.rope_init_fn(
self.config)
def forward(self, x, position_ids):
x_dtype = x.dtype
inv_freq = constant(self.original_inv_freq)
inv_freq_unsqueezed = unsqueeze(unsqueeze(inv_freq, 0), 2)
b = shape(position_ids, 0)
inv_freq_dim0 = shape(inv_freq, 0)
one = constant(np.array([1], dtype=np.int64))
expand_shape = concat(
[unsqueeze(b, 0), unsqueeze(inv_freq_dim0, 0), one], dim=0)
inv_freq_expanded = expand(inv_freq_unsqueezed, expand_shape)
position_ids_expanded = unsqueeze(position_ids, 1)
inv_freq_float32 = cast(inv_freq_expanded, "float32")
position_ids_float32 = cast(position_ids_expanded, "float32")
freqs_t = matmul(inv_freq_float32, position_ids_float32)
freqs = transpose(freqs_t, 1, 2)
emb = concat([freqs, freqs], dim=-1)
# 6. Apply cos, sin, and scaling
# PyTorch: emb.cos() * self.attention_scaling
cos_emb = cos(emb)
cos_scaled = mul(cos_emb, self.attention_scaling)
# PyTorch: emb.sin() * self.attention_scaling
sin_emb = sin(emb)
sin_scaled = mul(sin_emb, self.attention_scaling)
# 7. Cast back to original dtype
# PyTorch: .to(dtype=x.dtype)
final_cos = cast(cos_scaled, x_dtype)
final_sin = cast(sin_scaled, x_dtype)
return final_cos, final_sin
class LlamaMLP(Module):
def __init__(self, config, dtype=None, mapping=Mapping()):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = ColumnLinear(
in_features=self.hidden_size,
out_features=self.intermediate_size,
bias=False,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True,
)
self.up_proj = ColumnLinear(
in_features=self.hidden_size,
out_features=self.intermediate_size,
bias=False,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True,
)
self.down_proj = ColumnLinear(
in_features=self.intermediate_size,
out_features=self.hidden_size,
bias=False,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
# PyTorch: self.act_fn(self.gate_proj(x)) * self.up_proj(x)
gated_x = self.act_fn(self.gate_proj(x))
up_x = self.up_proj(x)
fused_x = mul(gated_x, up_x)
down_proj = self.down_proj(fused_x)
return down_proj
class LlamaAttention(Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self,
config: HydraConfig,
layer_idx: int,
mapping=Mapping(),
dtype=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim",
config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.is_causal = True
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.tp_size = mapping.tp_size
self.hidden_size = config.hidden_size
self.q_proj = ColumnLinear(
in_features=config.hidden_size,
out_features=config.num_attention_heads * self.head_dim,
bias=False,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True,
)
self.k_proj = ColumnLinear(
in_features=config.hidden_size,
out_features=config.num_key_value_heads * self.head_dim,
bias=False,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True,
)
self.v_proj = ColumnLinear(
in_features=config.hidden_size,
out_features=config.num_key_value_heads * self.head_dim,
bias=False,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True,
)
self.o_proj = ColumnLinear(
in_features=config.num_attention_heads * self.head_dim,
out_features=config.hidden_size,
bias=False,
dtype=dtype,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
gather_output=True,
)
def forward(
self,
hidden_states,
position_embeddings,
attention_mask=None,
past_key_value=None,
cache_position=None,
):
b, s = shape(hidden_states, 0), shape(hidden_states, 1)
# 1. Q, K, V Projections
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# 2. Reshape and Transpose to [batch, num_heads, seq_len, head_dim]
# PyTorch: .view(hidden_shape).transpose(1, 2)
query_states = transpose(
view(
q,
[0, 0, self.num_attention_heads // self.tp_size, self.head_dim
]), 1, 2)
key_states = transpose(
view(
k,
[0, 0, self.num_key_value_heads // self.tp_size, self.head_dim
]), 1, 2)
value_states = transpose(
view(
v,
[0, 0, self.num_key_value_heads // self.tp_size, self.head_dim
]), 1, 2)
# 3. Apply Rotary Position Embedding
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states,
key_states, cos, sin)
# 4. Maybe pass it...
# if past_key_value is not None:
# # sin and cos are specific to RoPE models; cache_position needed for the static cache
# cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# 5. Attention Computation
attn_output = eager_attention_forward(
query=query_states,
key=key_states,
value=value_states,
num_key_value_groups=self.num_key_value_groups,
scaling=self.scaling,
attention_mask=attention_mask,
)
# 6. Final Reshape and Projection
# PyTorch: attn_output.reshape(*input_shape, -1).contiguous()
attn_output = view(attn_output, [0, 0, -1])
attn_output = self.o_proj(attn_output)
return attn_output
class LlamaDecoderLayer(Module):
def __init__(self, config: HydraConfig, layer_idx: int, mapping=Mapping()):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config,
layer_idx=layer_idx,
mapping=mapping)
self.mlp = LlamaMLP(config, mapping=mapping)
self.input_layernorm = RmsNorm(
normalized_shape=config.hidden_size,
dtype=config.dtype,
)
self.post_attention_layernorm = RmsNorm(
normalized_shape=config.hidden_size,
dtype=config.dtype,
)
def forward(
self,
hidden_states,
attention_mask=None,
position_ids=None,
past_key_value=None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position=None,
position_embeddings=None, # necessary, but kept here for BC
):
residual = hidden_states
normed_hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_output = self.self_attn(
hidden_states=normed_hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
)
hidden_states = add(residual, attn_output)
# Fully Connected
residual = hidden_states
normed_hidden_states = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(normed_hidden_states)
hidden_states = add(residual, mlp_output)
return hidden_states
# refer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class PrefixEmbeddingLayer(Module):
def __init__(self, config: HydraConfig, mapping=Mapping()):
super().__init__()
self.vocab_size = config.vocab_size
self.layer = LlamaDecoderLayer(config=config,
layer_idx=0,
mapping=mapping)
self.norm = RmsNorm(
normalized_shape=config.hidden_size,
dtype=config.dtype,
)
self.rotary_emb = LlamaRotaryEmbedding(config=config, mapping=mapping)
def forward(
self,
inputs_embeds,
position_ids,
attention_mask=None,
):
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
hidden_states = self.layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False,
position_embeddings=position_embeddings,
)
hidden_states = self.norm(hidden_states)
return hidden_states
# HydraForCausalLM is a thin wrapper that picks parent class for GenericHydraForCausalLM.
# All hydra functionality is defined in GenericHydraForCausalLM.
class HydraForCausalLM(PretrainedModel):
config_class = HydraConfig
def __init__(self, config: HydraConfig):
super().__init__(config)
BaseLM = QWenForCausalLM if hasattr(
config,
"model_type") and "qwen" in config.model_type else LLaMAForCausalLM
class GenericHydraForCausalLM(BaseLM):
def __init__(self, config: HydraConfig):
super().__init__(config)
self.num_hydra_heads = config.num_hydra_heads
self.num_hydra_layers = config.num_hydra_layers
self.hidden_size = config.hidden_size
self.vocab_size = config.vocab_size
vocab_size_padded = pad_vocab_size(self.vocab_size,
config.mapping.tp_size)
base_kwargs = config.to_dict()
prefix_config = BaseLM.config_class(**base_kwargs)
self.prefix_embedding_layer = PrefixEmbeddingLayer(
prefix_config)
self.hydra_heads = ModuleList([
HydraPrefixMLP(num_layers=self.num_hydra_layers - 1,
hidden_size=config.hidden_size,
vocab_size=vocab_size_padded,
hydra_head_idx=i,
hidden_act=config.hidden_act,
dtype=config.dtype,
mapping=config.mapping)
for i in range(self.num_hydra_heads)
])
self.input_embed_fn = self.transformer.vocab_embedding
self.max_hydra_token_len = config.max_draft_len
def forward(self, *args, **kwargs):
output_original = True
hidden_states = super().forward(*args, **kwargs)
if kwargs['use_cache']:
if default_net().plugin_config.paged_kv_cache:
lm_logits, hidden_states, _ = hidden_states
else:
lm_logits, presents, hidden_states = hidden_states
if self.mapping.is_last_pp_rank():
position_ids = kwargs["position_ids"]
hidden_states_3d = unsqueeze(
hidden_states, 1) # Shape: [B, H] -> [B, 1, H]
prefix_embedding = self.prefix_embedding_layer(
inputs_embeds=hidden_states_3d,
position_ids=position_ids,
attention_mask=None,
)
_, topk_ids = topk(lm_logits, k=1, dim=-1)
next_embedding = self.input_embed_fn(squeeze(topk_ids, -1))
# TODO: Need to convert back into for-loop
# prefix_embedding and next_embedding are 2D: [batch, hidden_size]
# prefix_embedding_3d = unsqueeze(prefix_embedding, 1) # -> [batch, 1, hidden_size]
next_embedding_3d = unsqueeze(
next_embedding, 1) # -> [batch, 1, hidden_size]
all_head_logits = []
# --- Head 0 ---
head_0_input = concat([prefix_embedding, next_embedding_3d],
dim=2)
# head_0_input = concat([next_embedding_3d, next_embedding_3d], dim=2)
head_0_logits = self.hydra_heads[0](head_0_input)
all_head_logits.append(squeeze(head_0_logits, dim=1))
# --- Head 1 ---
_, next_token_ids_1 = topk(head_0_logits, k=1, dim=-1)
next_embedding_1 = self.input_embed_fn(
squeeze(next_token_ids_1, -1))
head_1_input = concat([head_0_input, next_embedding_1],
dim=2)
head_1_logits = self.hydra_heads[1](head_1_input)
all_head_logits.append(squeeze(head_1_logits, dim=1))
# --- Head 2 ---
_, next_token_ids_2 = topk(head_1_logits, k=1, dim=-1)
next_embedding_2 = self.input_embed_fn(
squeeze(next_token_ids_2, -1))
head_2_input = concat([head_1_input, next_embedding_2],
dim=2)
head_2_logits = self.hydra_heads[2](head_2_input)
all_head_logits.append(squeeze(head_2_logits, dim=1))
# --- Head 3 ---
_, next_token_ids_3 = topk(head_2_logits, k=1, dim=-1)
next_embedding_3 = self.input_embed_fn(
squeeze(next_token_ids_3, -1))
head_3_input = concat([head_2_input, next_embedding_3],
dim=2)
head_3_logits = self.hydra_heads[3](head_3_input)
all_head_logits.append(squeeze(head_3_logits, dim=1))
medusa_logits = stack(all_head_logits, dim=0)
medusa_logits.mark_output('medusa_logits',
self.config.logits_dtype)
else:
hidden_states.mark_output('hidden_states_output',
self.config.dtype)
if kwargs['use_cache'] and default_net(
).plugin_config.paged_kv_cache == False:
if self.mapping.is_last_pp_rank():
if output_original:
return (medusa_logits, lm_logits, presents)
return (medusa_logits, presents)
return (hidden_states, presents)
else:
if self.mapping.is_last_pp_rank():
if output_original:
return medusa_logits, lm_logits
return medusa_logits
return hidden_states
def prepare_inputs(self, *args, **kwargs):
kwargs['speculative_decoding_draft_tokens_external'] = False
kwargs['max_draft_len'] = self.max_hydra_token_len
return super().prepare_inputs(*args, **kwargs)
self.model = GenericHydraForCausalLM(config)
# Specialization to redirect accesses to self.model
def __getattribute__(self, name):
if name == 'model' or '__' in name:
return object.__getattribute__(self, name)
else:
model = object.__getattribute__(self, 'model')
return model.__getattribute__(name)
# Override specialized __setattr__ defined in Module
def __setattr__(self, name, value) -> None:
object.__setattr__(self, name, value)
@classmethod
def from_hugging_face(
cls,
hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
dtype: str = 'auto',
mapping: Optional[Mapping] = None,
quant_config: Optional[QuantConfig] = None,
**kwargs):
import transformers
assert hf_model_or_dir is not None
speculative_model_dir = kwargs.get('speculative_model', None)
use_preloading = isinstance(hf_model_or_dir,
transformers.PreTrainedModel)
if use_preloading:
hf_model = hf_model_or_dir
hf_config_or_dir = hf_model.config
else:
hf_model_dir = hf_model_or_dir
hf_config_or_dir = hf_model_or_dir
config = HydraConfig.from_hugging_face(hf_config_or_dir,
dtype=dtype,
mapping=mapping,
quant_config=quant_config,
**kwargs)
# ModelOpt ckpt has combined base model and Hydra-head
is_modelopt_ckpt = True if not speculative_model_dir else False
if not use_preloading:
trust_remote_code = kwargs.pop('trust_remote_code', True)
if is_modelopt_ckpt:
hf_model = LLaMAForCausalLM.from_hugging_face(
hf_model_dir,
dtype,
mapping=mapping,
quant_config=quant_config,
**kwargs)
else:
hf_model = AutoModelForCausalLM.from_pretrained(
hf_model_dir,
torch_dtype="auto",
trust_remote_code=trust_remote_code)
assert isinstance(hf_model, transformers.PreTrainedModel)
if is_modelopt_ckpt:
weights = {
name: numpy_to_torch(param.raw_value)
for name, param in hf_model.named_parameters()
}
else:
weights = convert_hf_llama(
hf_model,
config.mapping,
dtype='float16',
use_parallel_embedding=config.use_parallel_embedding)
model = cls(config)
if is_modelopt_ckpt:
num_hydra_heads = config.config.num_hydra_heads
num_hydra_layers = config.config.num_hydra_layers
speculative_model_dir = hf_model_or_dir
else:
config_file = speculative_model_dir / "config.json"
with open(config_file) as fp:
model_config = json.load(fp)
num_hydra_heads = kwargs[
'speculative_config'].num_hydra_heads if 'speculative_config' in kwargs else model_config.get(
'hydra_num_heads', None)
num_hydra_layers = model_config.get('hydra_num_layers', None)
hydra_weights = load_hydra_hf(hydra_path=speculative_model_dir,
num_hydra_heads=num_hydra_heads,
num_hydra_layers=num_hydra_layers,
mapping=mapping,
dtype="float16",
base_config=hf_model.config,
is_modelopt_ckpt=is_modelopt_ckpt)
weights.update(hydra_weights)
model.load(weights)
return model
As you can see, most of it is the PrefixEmbeddingLayer. Thankfully, this custom model no longer triggers illegal memory access.
I’m also including the source code that reads Hydra Head weights:
def load_hydra_hf(hydra_path: str,
num_hydra_heads: int,
num_hydra_layers: int,
base_config: PretrainedConfig,
mapping=Mapping(),
dtype='float32',
use_weight_only=False,
plugin_weight_only_quant_type=None,
is_modelopt_ckpt=False):
if is_modelopt_ckpt:
from safetensors.torch import load_file
state_dict = {}
for filename in sorted(Path(hydra_path).glob("*.safetensors")):
print(f"Loading the weights of Hydra heads from {filename}")
state_dict.update(load_file(filename))
else:
is_ckpt_safetensors = False
ckpt_file = Path(hydra_path) / "hydra_lm_head.pt"
if not ckpt_file.exists():
ckpt_file = Path(hydra_path) / "hydra_lm_head.safetensors"
is_ckpt_safetensors = True
if is_ckpt_safetensors:
logger.INFO("Safetensors Found ...")
from safetensors.torch import load_file
state_dict = load_file(ckpt_file)
else:
state_dict = torch.load(ckpt_file, map_location="cpu")
    torch_dtype = str_dtype_to_torch(dtype)
    weights = {}

    # Embedding
    # embedding_weight = state_dict["prefix_embeding_layer.embed_tokens.weight"].clone().to(torch_dtype)
    # split_emb = split(embedding_weight, mapping.tp_size, mapping.tp_rank, dim=0)
    # weights["prefix_embedding_layer.layer.vocab_embedding.weight"] = split_emb

    # Attention (QKV, O)
    q_w = state_dict["prefix_embeding_layer.layers.0.self_attn.q_proj.weight"].clone().to(torch_dtype)
    k_w = state_dict["prefix_embeding_layer.layers.0.self_attn.k_proj.weight"].clone().to(torch_dtype)
    v_w = state_dict["prefix_embeding_layer.layers.0.self_attn.v_proj.weight"].clone().to(torch_dtype)
    o_w = state_dict["prefix_embeding_layer.layers.0.self_attn.o_proj.weight"].clone().to(torch_dtype)

    weights["prefix_embedding_layer.layer.self_attn.q_proj.weight"] = split_matrix_tp(
        q_w, mapping.tp_size, mapping.tp_rank, dim=0)
    weights["prefix_embedding_layer.layer.self_attn.k_proj.weight"] = split_matrix_tp(
        k_w, mapping.tp_size, mapping.tp_rank, dim=0)
    weights["prefix_embedding_layer.layer.self_attn.v_proj.weight"] = split_matrix_tp(
        v_w, mapping.tp_size, mapping.tp_rank, dim=0)
    weights["prefix_embedding_layer.layer.self_attn.o_proj.weight"] = split_matrix_tp(
        o_w, mapping.tp_size, mapping.tp_rank, dim=1)

    # MLP (gate, up, down)
    weights["prefix_embedding_layer.layer.mlp.gate_proj.weight"] = split_matrix_tp(
        state_dict["prefix_embeding_layer.layers.0.mlp.gate_proj.weight"].clone().to(torch_dtype),
        mapping.tp_size, mapping.tp_rank, dim=0)
    weights["prefix_embedding_layer.layer.mlp.up_proj.weight"] = split_matrix_tp(
        state_dict["prefix_embeding_layer.layers.0.mlp.up_proj.weight"].clone().to(torch_dtype),
        mapping.tp_size, mapping.tp_rank, dim=0)
    weights["prefix_embedding_layer.layer.mlp.down_proj.weight"] = split_matrix_tp(
        state_dict["prefix_embeding_layer.layers.0.mlp.down_proj.weight"].clone().to(torch_dtype),
        mapping.tp_size, mapping.tp_rank, dim=1)

    # LayerNorm (no need to split)
    weights["prefix_embedding_layer.layer.input_layernorm.weight"] = state_dict[
        "prefix_embeding_layer.layers.0.input_layernorm.weight"].clone().to(torch_dtype)
    weights["prefix_embedding_layer.layer.post_attention_layernorm.weight"] = state_dict[
        "prefix_embeding_layer.layers.0.post_attention_layernorm.weight"].clone().to(torch_dtype)
    weights["prefix_embedding_layer.norm.weight"] = state_dict[
        "prefix_embeding_layer.norm.weight"].clone().to(torch_dtype)
    # Load Hydra heads weights
    for i in range(num_hydra_heads):
        w = state_dict[f"hydra_mlp.{i}.1.linear.weight"].clone().to(torch_dtype)
        weights[f"hydra_heads.{i}.hydra_mlp.linear.weight"] = split(
            w, mapping.tp_size, mapping.tp_rank, dim=0)
        weights[f"hydra_heads.{i}.hydra_mlp.linear.bias"] = state_dict[
            f"hydra_mlp.{i}.1.linear.bias"].clone().to(torch_dtype)

        # res_connection weights
        w_res = state_dict[f"hydra_mlp.{i}.1.res_connection.weight"].clone().to(torch_dtype)
        weights[f"hydra_heads.{i}.hydra_mlp.res_connection.weight"] = split(
            w_res, mapping.tp_size, mapping.tp_rank, dim=0)
        weights[f"hydra_heads.{i}.hydra_mlp.res_connection.bias"] = state_dict[
            f"hydra_mlp.{i}.1.res_connection.bias"].clone().to(torch_dtype)

        for l_idx in range(num_hydra_layers - 1):
            seq_idx = 3 + 2 * l_idx  # 3, 5, 7, 9...
            w = state_dict[f"hydra_mlp.{i}.{seq_idx}.linear.weight"].clone().to(torch_dtype)
            weights[f"hydra_heads.{i}.hydra_mlps.{l_idx}.linear.weight"] = split(
                w, mapping.tp_size, mapping.tp_rank, dim=0)
            weights[f"hydra_heads.{i}.hydra_mlps.{l_idx}.linear.bias"] = state_dict[
                f"hydra_mlp.{i}.{seq_idx}.linear.bias"].clone().to(torch_dtype)

        # Load lm_head
        w_lm = state_dict[f"hydra_lm_head.{i}.1.weight"].clone().to(torch_dtype)
        weights[f"hydra_heads.{i}.hydra_lm_head.weight"] = split(
            w_lm, mapping.tp_size, mapping.tp_rank, dim=0)
        if f"hydra_lm_head.{i}.1.bias" in state_dict:
            weights[f"hydra_heads.{i}.hydra_lm_head.bias"] = state_dict[
                f"hydra_lm_head.{i}.1.bias"].clone().to(torch_dtype)

    return weights
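The index arithmetic in the loop above (index 1, then seq_idx = 3 + 2 * l_idx) presumably comes from each head being stored as an nn.Sequential in which the parameter-free activations occupy the even slots, so only the odd indices appear in the checkpoint keys. A toy reproduction of that key layout (toy dimensions and a stand-in MLP module, not Hydra's actual layers):

```python
import torch.nn as nn


class ToyHydraMLP(nn.Module):
    """Stand-in for Hydra's MLP block: mirrors the "linear" /
    "res_connection" parameter names that load_hydra_hf() reads."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.res_connection = nn.Linear(in_features, out_features)


# A head built as (input transform, MLP, act, MLP, act, MLP) puts the
# MLP blocks at indices 1, 3, 5, ... — exactly the 3 + 2 * l_idx pattern.
head = nn.Sequential(
    nn.Identity(),        # index 0: no parameters, no state_dict keys
    ToyHydraMLP(16, 8),   # index 1: read as hydra_mlp.{i}.1.*
    nn.SiLU(),            # index 2: no parameters
    ToyHydraMLP(8, 8),    # index 3: l_idx=0 -> seq_idx=3
    nn.SiLU(),            # index 4: no parameters
    ToyHydraMLP(8, 8),    # index 5: l_idx=1 -> seq_idx=5
)

keys = list(head.state_dict().keys())
print(keys)
```

Only the odd indices show up in the keys, which is why the loader can skip the activations entirely.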
Test Results
My machine specs:
- RTX 4090 x1 (24GB VRAM)
- 64GB RAM
- Intel i7 13th Gen CPU
I ran the same prompts (such as “Who are you?”) for 10 consecutive runs at batch_size=1 (with --run_profiling enabled in run.py) and measured the average latency.
This is a rough benchmark, but you can clearly see that while my Hydra port is not faster than Medusa, it is indeed faster than native Vicuna-7B!
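For reference, the measurement idea behind these numbers can be sketched as follows. This is a hypothetical helper, not run.py's actual profiler, and generate_fn stands in for a call into the TensorRT-LLM session:

```python
import time


def average_latency(generate_fn, prompt, runs=10, warmup=2):
    """Rough wall-clock latency: average over `runs` calls after a warmup."""
    for _ in range(warmup):
        generate_fn(prompt)  # let kernels, caches, and allocators warm up
    start = time.perf_counter()
    for _ in range(runs):
        generate_fn(prompt)
    return (time.perf_counter() - start) / runs
```

The warmup calls matter on GPU: the first inference often pays one-time costs (CUDA context, graph capture) that would skew a 10-run average.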
Hydra was shown to outperform Medusa in the paper, so what happened here? I see two possible explanations:
- Hydra Heads are sequentially dependent and not parallel-friendly like Medusa, which may reduce GPU parallelism.
- More likely: my implementation isn’t optimal XD. After all, Medusa was officially optimized by the TensorRT-LLM team, while I merely pieced things together and even rewrote a custom LlamaModel just to avoid memory issues, possibly sacrificing some graph optimization.
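The first point can be made concrete with a toy comparison. Plain linear heads stand in for the real Medusa/Hydra MLPs here, and the dimensions are made up; the point is only the data-flow difference:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, vocab, n_heads = 8, 32, 3
h = torch.randn(1, hidden)  # base-model hidden state at the last position
embed = nn.Embedding(vocab, hidden)
medusa_heads = [nn.Linear(hidden, vocab) for _ in range(n_heads)]
hydra_heads = [nn.Linear(hidden * 2, vocab) for _ in range(n_heads)]

# Medusa-style: every head sees only the hidden state, so all heads can
# run independently (and thus in parallel on the GPU).
medusa_draft = [head(h).argmax(-1) for head in medusa_heads]

# Hydra-style: head k also consumes the embedding of the token drafted by
# head k-1, forming a sequential chain that cannot be batched across heads.
prev_tok = torch.tensor([0])  # pretend this is the last accepted token
hydra_draft = []
for head in hydra_heads:
    tok = head(torch.cat([h, embed(prev_tok)], dim=-1)).argmax(-1)
    hydra_draft.append(tok)
    prev_tok = tok  # the next head depends on this result
```

Each Hydra head must wait for the previous head's argmax, so the draft phase is a chain of small kernels instead of one wide batched matmul.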
Conclusion
Supporting Hydra on TensorRT-LLM took me two full weekends and every free evening in a week — around 30 to 35 hours in total. It’s probably the most time-consuming side project I’ve ever done.
Still, hacking through an inference acceleration framework and adding a new speculative decoding mode felt very rewarding. While many modules were copied from Medusa, quite a few issues were unique to Hydra.
It’s hard to document every fine-grained detail of the implementation and debugging. More than once I almost gave up and moved on to the next project XD
But now I’m much more familiar with speculative decoding trees on TensorRT-LLM. Maybe I’ll explore other decoding modes or further optimize Hydra — especially the C++ backend and CUDA kernels.
If you’re interested, check out my implementation here: GitHub Repo
I hope to add documentation soon! Thanks for reading~