
Supporting Hydra Speculative Decoding on the TensorRT-LLM Python Session

Last Updated on 2025-07-01 by Clay

Introduction

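I've previously read about many different Speculative Decoding techniques for accelerating inference, and I've implemented several of these architectures in PyTorch, including the model architecture, training, and inference scripts (fast-llm-inference). This time, of course, I have a new target.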

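A month or two ago I read about Hydra, which can be viewed as a variant built on the Medusa architecture. This time I wanted to use the TensorRT-LLM inference acceleration framework to run a model that loads the official Hydra weights. Think of it as a small side project of mine.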

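Goal: on TensorRT-LLM, with the Python Session, have the Hydra Heads produce reasonable draft tokens for the base model to verify. (If you want to see my implementation branch: I haven't discussed it with the TensorRT-LLM team at all, so it simply lives on my GitHub: https://github.com/ccs96307/TensorRT-LLM/tree/support-spec-decode-hydra/examples/hydra)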


Hydra Recap

For a fuller explanation of Hydra, see my earlier notes: [Paper Reading] Hydra: Sequentially-Dependent Draft Heads for Medusa Decoding

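In short, Hydra and Medusa are very similar architectures: in both, each head decodes the token for a different future time step. Hydra, however, also feeds each head the token decoded by the previous head, which further raises the acceptance rate of the Hydra Heads.

Put another way: Medusa's heads predict without seeing the exact token produced by the previous head; they only see the comparatively fuzzy hidden states. Hydra instead tells the head that is about to generate exactly which token the previous step produced, by concatenating that token's embedding to its input.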

Reading the official source code (https://github.com/zankner/Hydra/blob/main/hydra/model/hydra_heads/prefix_mlp_head.py), we can see the implementation:

self.hydra_head = HydraMLP(
    hydra_num_layers=self.hydra_num_layers,
    hydra_num_heads=self.hydra,
    grounded_heads=self.grounded_heads,
    input_embed_fn=self.base_model.model.embed_tokens,
    base_config=self.config,
    lm_head_init_weight=base_model.lm_head.weight.data
)
self.hydra_lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 


Besides the number of layers and heads, you also have to pass in input_embed_fn, which is normally the model's embedding layer.

HydraMLP's forward() shows clearly how the Hydra Heads operate:

def forward(self, base_hidden_states, input_ids=None, noise=None):
    """
    Forward pass of the MLP.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Output after the MLP.
    """

    hydra_hidden_states = []
    if self.grounded_heads:
        assert input_ids is not None, "Input ids must be provided for grounded heads"
        with torch.inference_mode():
            input_embeds = self.input_embed_fn(input_ids)
        if noise is not None:
            input_embeds = input_embeds + noise
        hydra_inputs = [base_hidden_states]
        for i in range(self.hydra_num_heads):
            # Move input embeddings back one spot for each hydra head idx
            hydra_inputs.append(torch.roll(input_embeds, shifts=-(i+1), dims=1))
        
        for i in range(self.hydra_num_heads):
            head_input = torch.cat(hydra_inputs[:i + 2], dim=-1) 
            hydra_hidden_states.append(self.hydra_mlp[i](head_input))
    else:
        for i in range(self.hydra_num_heads):
            hydra_hidden_states.append(self.hydra_mlp[i](base_hidden_states))


It also shows exactly where tokens are converted into embeddings:

with torch.inference_mode():
    input_embeds = self.input_embed_fn(input_ids)
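To make the shapes in the grounded branch concrete, here is a minimal NumPy sketch (not the official code) that mimics how each head's input is assembled, with torch.roll replaced by np.roll and tiny made-up dimensions; the random arrays merely stand in for base_hidden_states and input_embeds.

```python
import numpy as np

def build_grounded_head_inputs(base_hidden_states, input_embeds, num_heads):
    """Mimic the grounded branch of HydraMLP.forward, returning one input per head.

    base_hidden_states, input_embeds: arrays of shape [batch, seq_len, hidden].
    """
    hydra_inputs = [base_hidden_states]
    for i in range(num_heads):
        # Shift the token embeddings left by i + 1 positions along the sequence axis,
        # so position t now holds the embedding of token t + i + 1.
        hydra_inputs.append(np.roll(input_embeds, shift=-(i + 1), axis=1))
    # Head i concatenates the base hidden state with the first i + 1 shifted embeddings.
    return [np.concatenate(hydra_inputs[:i + 2], axis=-1) for i in range(num_heads)]

batch, seq_len, hidden, num_heads = 2, 5, 8, 4
rng = np.random.default_rng(0)
base = rng.standard_normal((batch, seq_len, hidden))
embeds = rng.standard_normal((batch, seq_len, hidden))

head_inputs = build_grounded_head_inputs(base, embeds, num_heads)
print([x.shape[-1] for x in head_inputs])  # [16, 24, 32, 40]
```

Each head's input is one hidden_size wider than the previous head's, which is exactly the growth that later has to be reproduced inside TensorRT-LLM.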


The authors' experiments show this to be about 1.1x faster than Medusa.

Now I want to support it on TensorRT-LLM and get it running.


Implementation Approach and Process

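So what does getting it to run on TensorRT-LLM actually involve?

First I defined my goal. Since I wanted to support the officially trained Hydra model, I downloaded the base vicuna-7b model, and then also downloaded ankner--hydra-vicuna-7b-v1.3 from the official GitHub.

After downloading it, I read config.json, which makes it clear that the model uses “prefix-mlp”, one of the several head architectures Hydra defines. This architecture is also what caused me plenty of trouble later XD

Medusa Heads run in parallel: the hidden_states at the end of the base model can be passed directly to every head. Hydra is different. The hidden states first pass through a Llama model with only a single layer, are then concatenated with the embedding of the highest-probability token from the model's original LM head, and only then go into the Hydra Heads.

The Hydra Heads are also not parallel but sequential: Token 1 predicted by Head 1 is turned by the embedding layer into Token Embedding 1, concatenated with the previous features, and passed to Head 2, and so on.

As a result, the Hydra Heads' input dimensions grow as 8192, 12288, 16384, 20480, ...: with Vicuna-7B's hidden size of 4096, head i (counting from 0) takes (i + 2) × 4096 inputs. I'll go into the trouble this causes shortly.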
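The sequential dependency can be sketched as a greedy draft loop. This is a hypothetical NumPy toy, not the TensorRT-LLM code: each "head" is just a random linear map from (i + 2) * hidden to vocab, embed is a random embedding table, and prefix_hidden stands in for the output of the one-layer prefix Llama.

```python
import numpy as np

def greedy_hydra_draft(prefix_hidden, base_logits, embed, heads):
    """Greedy sequential drafting: each head sees the embeddings of all earlier tokens.

    prefix_hidden: [batch, hidden]; base_logits: [batch, vocab];
    embed: [vocab, hidden]; heads[i]: [(i + 2) * hidden, vocab].
    """
    # The token proposed by the base LM head (argmax), turned into an embedding.
    next_ids = base_logits.argmax(axis=-1)
    head_input = np.concatenate([prefix_hidden, embed[next_ids]], axis=-1)
    draft_ids, widths = [], []
    for i, weight in enumerate(heads):
        widths.append(head_input.shape[-1])
        logits = head_input @ weight          # [batch, vocab]
        next_ids = logits.argmax(axis=-1)     # greedy draft token of head i
        draft_ids.append(next_ids)
        if i + 1 < len(heads):
            # Append this head's token embedding so the next head can condition on it.
            head_input = np.concatenate([head_input, embed[next_ids]], axis=-1)
    return np.stack(draft_ids, axis=1), widths

batch, hidden, vocab, num_heads = 2, 8, 50, 4
rng = np.random.default_rng(0)
embed = rng.standard_normal((vocab, hidden))
heads = [rng.standard_normal(((i + 2) * hidden, vocab)) for i in range(num_heads)]
prefix_hidden = rng.standard_normal((batch, hidden))
base_logits = rng.standard_normal((batch, vocab))

drafts, widths = greedy_hydra_draft(prefix_hidden, base_logits, embed, heads)
print(drafts.shape, widths)  # (2, 4) [16, 24, 32, 40]
```

With hidden = 4096 the same widths become 8192, 12288, 16384 and 20480. In the model.py below, each head is likewise conditioned only on the previous head's top-1 token (via topk with k=1).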


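Next, note that TensorRT-LLM has two runtimes, Python and C++, and this time I implemented the Python one. TensorRT-LLM also already supports Medusa, so since Hydra's architecture closely resembles Medusa's, I could largely define the new architecture by imitating Medusa.

So what I did was basically the following:

  • Created a new hydra folder under tensorrt_llm/models/ (copied from medusa/)
  • Rewrote model.py, weight.py, and config.py (https://github.com/ccs96307/TensorRT-LLM/tree/support-spec-decode-hydra/tensorrt_llm/models/hydra)
  • For verification, created examples/hydra/ with three scripts, convert.sh, build.sh, and run.sh, to confirm the model architecture really runs end to end
  • Added matching Hydra support in examples/run.py and tensorrt_llm/runtime/generation.py. These are minor details: I mostly followed Medusa's existing flow, but made sure that passing --speculative_decoding_mode hydra triggers our new Hydra branch all the way through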

examples/hydra/convert.sh, build.sh, run.sh

The first script to run is convert.sh:

#!/bin/bash


python convert_checkpoint.py \
    --model_dir ./lmsys--vicuna-7b-v1.3/ \
    --hydra_model_dir ./ankner--hydra-vicuna-7b-v1.3 \
    --output_dir ./tllm_checkpoint_1gpu_hydra \
    --dtype float16 \
    --num_hydra_heads 4 \
    --num_hydra_layers 4


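This script converts a pretrained LLM (such as LLaMA or Qwen2), together with the Hydra head weights from --hydra_model_dir, into TensorRT-LLM's optimized checkpoint format, so that it can later be built into a fast TensorRT inference engine.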

#!/bin/bash


trtllm-build \
    --checkpoint_dir tllm_checkpoint_1gpu_hydra \
    --output_dir ./tmp/hydra/7B/trt_engine/fp16/1-gpu/ \
    --gemm_plugin float16 \
    --speculative_decoding_mode hydra \
    --max_batch_size 4


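Next is build.sh, the most critical step in deploying a Hydra-enabled LLM with TensorRT-LLM.

Its core job is to read the checkpoint created in the previous step and build from it a highly optimized, directly runnable TensorRT engine.

If the model definition is broken, or hydra/weight.py fails to load the Hydra Heads' weights correctly, this step will keep throwing errors. The same happens if TensorRT cannot infer the tensor dimensions of the computation graph for the architecture we defined.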
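Because a weight-shape mismatch in hydra/weight.py only surfaces as a build failure, it can help to compute the shapes model.py expects before building. A minimal sketch, assuming the HydraPrefixMLP layout in the model.py below (head i's first HydraResBlock takes (i + 2) * hidden_size inputs and has a residual projection, followed by num_hydra_layers - 1 square HydraResBlocks whose residual is an identity, then a per-head LM head with bias) and tensor parallelism of 1. The parameter names mirror the attribute names in model.py but are illustrative only and need not match the actual checkpoint keys.

```python
def expected_hydra_head_shapes(num_heads, num_layers, hidden_size, vocab_size):
    """Return {illustrative_name: shape} for the head weights implied by HydraPrefixMLP.

    Linear weights are listed as [out_features, in_features]; tp_size is assumed to be 1.
    """
    shapes = {}
    for i in range(num_heads):
        in_features = hidden_size * (i + 2)  # num_condition = i + 1
        prefix = f"hydra_heads.{i}"
        # First block: main projection plus residual projection, both with bias.
        for proj in ("linear", "res_connection"):
            shapes[f"{prefix}.hydra_mlp.{proj}.weight"] = (hidden_size, in_features)
            shapes[f"{prefix}.hydra_mlp.{proj}.bias"] = (hidden_size,)
        # Remaining blocks are square and use an identity residual (no weights).
        for j in range(num_layers - 1):
            shapes[f"{prefix}.hydra_mlps.{j}.linear.weight"] = (hidden_size, hidden_size)
            shapes[f"{prefix}.hydra_mlps.{j}.linear.bias"] = (hidden_size,)
        shapes[f"{prefix}.hydra_lm_head.weight"] = (vocab_size, hidden_size)
        shapes[f"{prefix}.hydra_lm_head.bias"] = (vocab_size,)
    return shapes

# Values matching convert.sh (--num_hydra_heads 4 --num_hydra_layers 4) and Vicuna-7B.
shapes = expected_hydra_head_shapes(num_heads=4, num_layers=4,
                                    hidden_size=4096, vocab_size=32000)
print(shapes["hydra_heads.0.hydra_mlp.linear.weight"])  # (4096, 8192)
print(shapes["hydra_heads.3.hydra_mlp.linear.weight"])  # (4096, 20480)
```

Comparing this table with the tensors weight.py actually produces is a cheap way to catch a mismatch before trtllm-build does.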

#!/bin/bash



python ../run.py \
    --engine_dir ./tmp/hydra/7B/trt_engine/fp16/1-gpu/ \
    --use_py_session \
    --tokenizer_dir ./lmsys--vicuna-7b-v1.3/ \
    --max_output_len=100 \
    --hydra_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
    --temperature 1.0 \
    --input_text "Once upon" \
    --debug_mode


From here on it's relatively simple: run.sh passes our input_text in so we can test Hydra's generation.
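The --hydra_choices string describes the draft tree: each inner list is a path, and its k-th entry picks which top-ranked candidate of head k to follow. A hypothetical sanity check (not part of run.py) is that no path is longer than the number of Hydra heads and that every path's parent is also in the tree; the helper also reports how many candidates each head must supply.

```python
import ast

def check_hydra_choices(choices_str, num_hydra_heads):
    """Hypothetical sanity check for a --hydra_choices string (not part of run.py).

    Returns (number of paths, tree depth, candidates needed from each head).
    """
    choices = [tuple(c) for c in ast.literal_eval(choices_str)]
    paths = set(choices)
    if len(paths) != len(choices):
        raise ValueError("duplicate paths in hydra_choices")
    depth = max(len(p) for p in paths)
    if depth > num_hydra_heads:
        raise ValueError(f"tree depth {depth} exceeds {num_hydra_heads} Hydra heads")
    for path in paths:
        # Every non-root path should extend a path that is itself in the tree.
        if len(path) > 1 and path[:-1] not in paths:
            raise ValueError(f"parent of {list(path)} is missing")
    # Largest candidate index used at each depth, plus one = top-k that head must provide.
    topk_per_head = [1 + max(p[d] for p in paths if len(p) > d) for d in range(depth)]
    return len(choices), depth, topk_per_head

print(check_hydra_choices("[[0], [0, 0], [1], [0, 1], [0, 0, 0]]", num_hydra_heads=4))
# (5, 3, [2, 2, 1])
```

Applied to the 63-path tree in run.sh with four heads, the check passes and returns (63, 4, [10, 10, 10, 2]): the tree is exactly as deep as the number of Hydra heads.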

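Problems

Modifying TensorRT-LLM to support Hydra involved many pitfalls. Below are a few of the most painful ones, some of which I at first wasn't sure could be solved at all.

Looking back after finishing, I'd say the main problems were: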

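  1. Could not get the logits produced by the Hydra heads: at first my Hydra acceptance rate was a flat 0! The cause turned out to be that model.py, the model definition, must declare medusa_logits.mark_output('medusa_logits', self.config.logits_dtype). This marks the tensor as an engine output, so the optimized engine keeps it instead of treating it as a disposable intermediate. I also found that I got no data when I used a different name; presumably something downstream (perhaps the kernel itself) hard-codes the name medusa_logits, so I simply followed the existing convention.
  2. CUDA illegal memory access: I debugged this for a long time and couldn't get anything running at first. Unwilling to throw away a whole week of effort, I gritted my teeth and tried commenting out each component in turn, and finally pinned the problem down to Hydra's prefix_embedding_layer. In the original implementation, this is a "one-layer Llama model" that never goes through an embedding layer (only the final hidden_states are passed in, so input_ids never trigger an embedding lookup). I first tried calling TensorRT-LLM's own Llama model directly, but whether I used the base model's AttentionParams or my own, it produced the same CUDA error at runtime. In the end I ported the Transformers implementation of LlamaModel and replaced every internal operation with its TensorRT-LLM equivalent (so that it still benefits from graph optimization).
  3. AssertionError: tensor /concat_L2636/CONCATENATION_1_output_0 has an invalid shape: this one also troubled me for a long time. After a lot of searching online, I found someone explaining that this error means TensorRT cannot infer the correct dimensions automatically. Indeed, each later Hydra Head has a larger input dimension and must also be concatenated with the previous token embeddings (see the Hydra architecture diagram in the previous section), so it seems plausible that automatic shape inference can't follow this growth. In the end I unrolled the loop by hand. I consider this the worst change in my modification, though: what if a different Hydra model has a different number of Hydra Heads? In the future I need to find a way to support a configurable number of heads.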
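A side note on point 2: as wired in the model.py below, PrefixEmbeddingLayer receives one position per row (hidden_states is unsqueezed to [B, 1, H]) and its KV cache update is commented out, so each query can only attend to its own key. A softmax over a single score is exactly 1, so the attention output equals the value vector whatever RoPE does to q and k. The plain NumPy scaled dot-product attention here (not the TensorRT-LLM functional ops) illustrates that identity:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def single_position_attention(q, k, v, scaling):
    """Scaled dot-product attention for tensors shaped [batch, heads, 1, head_dim]."""
    scores = (q @ np.swapaxes(k, -1, -2)) * scaling  # [batch, heads, 1, 1]
    weights = softmax(scores, axis=-1)               # a single score -> weight 1.0
    return weights @ v, weights

rng = np.random.default_rng(0)
batch, heads, head_dim = 2, 4, 8
q, k, v = (rng.standard_normal((batch, heads, 1, head_dim)) for _ in range(3))
out, weights = single_position_attention(q, k, v, scaling=head_dim ** -0.5)
print(np.allclose(weights, 1.0), np.allclose(out, v))  # True True
```

If I read the wiring correctly, the attention sub-block therefore currently reduces to o_proj(v_proj(x)) on the normalized input; attending over earlier positions would need a KV cache or a multi-position input.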

So in the end, my model.py architecture looks like this:

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
from typing import Optional, Union

import numpy as np
import torch
from transformers import AutoModelForCausalLM

from tensorrt_llm._utils import numpy_to_torch
from tensorrt_llm.models.hydra.weight import load_hydra_hf
from tensorrt_llm.models.llama.model import LLaMAForCausalLM, RmsNorm
from tensorrt_llm.models.qwen.model import QWenForCausalLM

from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import (ACT2FN, add, cast, concat, constant, cos, div,
                           expand, matmul, mul, shape, sin, slice, softmax,
                           squeeze, stack, topk, transpose, unsqueeze, view)
from ...layers import ColumnLinear
from ...mapping import Mapping
from ...module import Module, ModuleList
from ..modeling_utils import PretrainedModel, QuantConfig
from .config import HydraConfig
from .weight import convert_hf_llama


# refer: https://github.com/zankner/Hydra/blob/main/hydra/model/hydra_heads/prefix_mlp_head.py#L44
class HydraResBlock(Module):

    def __init__(
            self,
            hidden_size,
            hidden_act="silu",
            num_condition=0,
            dtype=None,
            mapping=Mapping(),
    ):
        super().__init__()

        input_size = hidden_size * (num_condition + 1)
        self.linear = ColumnLinear(input_size,
                                   hidden_size,
                                   dtype=dtype,
                                   tp_group=mapping.tp_group,
                                   tp_size=mapping.tp_size,
                                   gather_output=True)
        self.res_connection = ColumnLinear(
            input_size,
            hidden_size,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True) if num_condition > 0 else torch.nn.Identity()

        self.hidden_act = hidden_act

    def forward(self, x):
        return self.res_connection(x) + ACT2FN[self.hidden_act](self.linear(x))


class HydraPrefixMLP(Module):

    def __init__(
            self,
            num_layers,
            hidden_size,
            vocab_size,
            hydra_head_idx,
            hidden_act="silu",
            dtype=None,
            mapping=Mapping(),
            lm_head_init_weight=None,
    ):
        super().__init__()
        self.hydra_mlp = HydraResBlock(hidden_size=hidden_size,
                                       num_condition=hydra_head_idx + 1,
                                       hidden_act=hidden_act,
                                       dtype=dtype,
                                       mapping=mapping)

        self.hydra_mlps = ModuleList([
            HydraResBlock(hidden_size=hidden_size,
                          hidden_act=hidden_act,
                          dtype=dtype,
                          mapping=mapping) for _ in range(num_layers)
        ])
        self.hydra_lm_head = ColumnLinear(hidden_size,
                                          vocab_size,
                                          bias=True,
                                          dtype=dtype,
                                          tp_group=mapping.tp_group,
                                          tp_size=mapping.tp_size,
                                          gather_output=True)

    def forward(self, x):
        hidden_states = self.hydra_mlp(x)

        for layer in self.hydra_mlps:
            hidden_states = layer(hidden_states)

        return self.hydra_lm_head(hidden_states)


def _compute_default_rope_parameters(
    config: Optional[HydraConfig] = None,
    **rope_kwargs,
):
    # if len(rope_kwargs) > 0:
    #     base = rope_kwargs["base"]
    #     dim = rope_kwargs["dim"]
    # elif config is not None:
    #     base = config.rope_theta
    #     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    #     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    #     dim = int(head_dim * partial_rotary_factor)

    base = getattr(config, "rope_theta", 10000.0)
    head_dim = config.hidden_size // config.num_attention_heads
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    dim = int(head_dim * partial_rotary_factor)

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies
    idx = np.arange(0, dim, 2, dtype=np.float32)
    inv_freq = 1.0 / (base**(idx / dim))

    return inv_freq, attention_factor


def _compute_llama3_parameters(
    config: HydraConfig,
    **rope_kwargs,
):
    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(
        config, **rope_kwargs)

    factor = 8.0
    low_freq_factor = 1.0
    high_freq_factor = 4.0
    old_context_len = 8192.0

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / inv_freq

    # Use numpy
    inv_freq_llama = np.where(wavelen > low_freq_wavelen, inv_freq / factor,
                              inv_freq)

    smooth_factor = (old_context_len / wavelen -
                     low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed_inv_freq = (
        1 - smooth_factor
    ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama

    is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen
                                                      >= high_freq_wavelen)

    inv_freq_llama = np.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    return inv_freq_llama, attention_factor


def rotate_half(x):
    """
    TensorRT-LLM functional version of rotate_half.
    Assumes x is a 4D tensor: [batch, num_heads, seq_len, head_dim]
    Splits the last dimension in half, rotates the halves, and concatenates them.
    """
    # Get dimensions as scalar tensors
    dim0 = squeeze(shape(x, 0), dim=0)
    dim1 = squeeze(shape(x, 1), dim=0)
    dim2 = squeeze(shape(x, 2), dim=0)
    last_dim = squeeze(shape(x, 3), dim=0)

    # Compute half of last_dim
    two = constant(np.array([2], dtype="int64"))
    half_dim = squeeze(div(last_dim, two), dim=0)

    # Create scalar zero for use in starts
    zero = constant(np.array(0, dtype="int64"))

    # Define starts and sizes for slicing
    starts1 = stack([zero, zero, zero, zero], dim=0)
    sizes1 = stack([dim0, dim1, dim2, half_dim], dim=0)
    starts2 = stack([zero, zero, zero, half_dim], dim=0)

    # Slice tensors into two halves along the last dimension
    x1 = slice(x, starts=starts1, sizes=sizes1)
    x2 = slice(x, starts=starts2, sizes=sizes1)

    # Negate the second half and concatenate
    neg_x2 = mul(x2, -1.0)
    return concat([neg_x2, x1], dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """
    TensorRT-LLM functional version of apply_rotary_pos_emb.
    """
    # PyTorch: cos.unsqueeze(unsqueeze_dim)
    cos_expanded = unsqueeze(cos, unsqueeze_dim)
    sin_expanded = unsqueeze(sin, unsqueeze_dim)

    # PyTorch: (q * cos) + (rotate_half(q) * sin)
    rotated_q = rotate_half(q)
    q_embed = add(mul(q, cos_expanded), mul(rotated_q, sin_expanded))

    # PyTorch: (k * cos) + (rotate_half(k) * sin)
    rotated_k = rotate_half(k)
    k_embed = add(mul(k, cos_expanded), mul(rotated_k, sin_expanded))

    return q_embed, k_embed


def repeat_kv(hidden_states, n_rep: int):
    """
    TensorRT-LLM functional version of repeat_kv
    """
    if n_rep == 1:
        return hidden_states

    batch, num_key_value_heads, slen, head_dim = (shape(hidden_states, 0),
                                                  shape(hidden_states, 1),
                                                  shape(hidden_states, 2),
                                                  shape(hidden_states, 3))

    hidden_states_unsqueezed = unsqueeze(hidden_states, 2)
    hidden_states_expanded = expand(
        hidden_states_unsqueezed,
        [batch, num_key_value_heads, n_rep, slen, head_dim])

    final_shape = [batch, num_key_value_heads * n_rep, slen, head_dim]
    return view(hidden_states_expanded, final_shape)


def eager_attention_forward(
    query,
    key,
    value,
    num_key_value_groups: int,
    scaling: float,
    dropout: float = 0.0,
    attention_mask=None,
    **kwargs,
):
    key_states = repeat_kv(key, num_key_value_groups)
    value_states = repeat_kv(value, num_key_value_groups)

    # Attention Scores: (Q @ K.T) * scaling
    key_states_T = transpose(key_states, 2, 3)
    attn_scores = matmul(query, key_states_T)
    attn_scores_scaled = mul(attn_scores, scaling)

    if attention_mask is not None:
        key_len = shape(key_states, 2)
        mask_shape = shape(attention_mask)

        causal_mask = slice(
            attention_mask,
            starts=[0, 0, 0, 0],
            sizes=[mask_shape[0], mask_shape[1], mask_shape[2], key_len])
        attn_scores_masked = add(attn_scores_scaled, causal_mask)
    else:
        attn_scores_masked = attn_scores_scaled

    # Softmax
    query_dtype = query.dtype
    attn_weights_fp32 = softmax(cast(attn_scores_masked, "float32"), dim=-1)
    attn_weights = cast(attn_weights_fp32, query_dtype)

    # Ignore dropout
    # Compute Attention Output: attn_weights @ V
    attn_output = matmul(attn_weights, value_states)

    # Transpose
    attn_output = transpose(attn_output, 1, 2)

    return attn_output


class LlamaRotaryEmbedding(Module):

    def __init__(self, config: HydraConfig, mapping=Mapping()):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # Vicuna-7B v1.3 is LLaMA-1 based and uses plain RoPE without llama3 frequency scaling.
        self.rope_init_fn = _compute_default_rope_parameters

        self.original_inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config)

    def forward(self, x, position_ids):
        x_dtype = x.dtype

        inv_freq = constant(self.original_inv_freq)

        inv_freq_unsqueezed = unsqueeze(unsqueeze(inv_freq, 0), 2)
        b = shape(position_ids, 0)
        inv_freq_dim0 = shape(inv_freq, 0)
        one = constant(np.array([1], dtype=np.int64))

        expand_shape = concat(
            [unsqueeze(b, 0), unsqueeze(inv_freq_dim0, 0), one], dim=0)

        inv_freq_expanded = expand(inv_freq_unsqueezed, expand_shape)
        position_ids_expanded = unsqueeze(position_ids, 1)

        inv_freq_float32 = cast(inv_freq_expanded, "float32")
        position_ids_float32 = cast(position_ids_expanded, "float32")

        freqs_t = matmul(inv_freq_float32, position_ids_float32)
        freqs = transpose(freqs_t, 1, 2)

        emb = concat([freqs, freqs], dim=-1)

        # 6. Apply cos, sin, and scaling
        # PyTorch: emb.cos() * self.attention_scaling
        cos_emb = cos(emb)
        cos_scaled = mul(cos_emb, self.attention_scaling)

        # PyTorch: emb.sin() * self.attention_scaling
        sin_emb = sin(emb)
        sin_scaled = mul(sin_emb, self.attention_scaling)

        # 7. Cast back to original dtype
        # PyTorch: .to(dtype=x.dtype)
        final_cos = cast(cos_scaled, x_dtype)
        final_sin = cast(sin_scaled, x_dtype)

        return final_cos, final_sin


class LlamaMLP(Module):

    def __init__(self, config, dtype=None, mapping=Mapping()):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = ColumnLinear(
            in_features=self.hidden_size,
            out_features=self.intermediate_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        self.up_proj = ColumnLinear(
            in_features=self.hidden_size,
            out_features=self.intermediate_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        self.down_proj = ColumnLinear(
            in_features=self.intermediate_size,
            out_features=self.hidden_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # PyTorch: self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        gated_x = self.act_fn(self.gate_proj(x))
        up_x = self.up_proj(x)
        fused_x = mul(gated_x, up_x)

        down_proj = self.down_proj(fused_x)
        return down_proj


class LlamaAttention(Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 config: HydraConfig,
                 layer_idx: int,
                 mapping=Mapping(),
                 dtype=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim",
            config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = True

        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.tp_size = mapping.tp_size
        self.hidden_size = config.hidden_size

        self.q_proj = ColumnLinear(
            in_features=config.hidden_size,
            out_features=config.num_attention_heads * self.head_dim,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )
        self.k_proj = ColumnLinear(
            in_features=config.hidden_size,
            out_features=config.num_key_value_heads * self.head_dim,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )
        self.v_proj = ColumnLinear(
            in_features=config.hidden_size,
            out_features=config.num_key_value_heads * self.head_dim,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )
        self.o_proj = ColumnLinear(
            in_features=config.num_attention_heads * self.head_dim,
            out_features=config.hidden_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask=None,
        past_key_value=None,
        cache_position=None,
    ):
        b, s = shape(hidden_states, 0), shape(hidden_states, 1)

        # 1. Q, K, V Projections
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # 2. Reshape and Transpose to [batch, num_heads, seq_len, head_dim]
        # PyTorch: .view(hidden_shape).transpose(1, 2)
        query_states = transpose(
            view(
                q,
                [0, 0, self.num_attention_heads // self.tp_size, self.head_dim
                 ]), 1, 2)
        key_states = transpose(
            view(
                k,
                [0, 0, self.num_key_value_heads // self.tp_size, self.head_dim
                 ]), 1, 2)
        value_states = transpose(
            view(
                v,
                [0, 0, self.num_key_value_heads // self.tp_size, self.head_dim
                 ]), 1, 2)

        # 3. Apply Rotary Position Embedding
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states,
                                                        key_states, cos, sin)
        # 4. KV cache update skipped: this layer currently runs without a KV cache
        # if past_key_value is not None:
        #     # sin and cos are specific to RoPE models; cache_position needed for the static cache
        #     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # 5. Attention Computation
        attn_output = eager_attention_forward(
            query=query_states,
            key=key_states,
            value=value_states,
            num_key_value_groups=self.num_key_value_groups,
            scaling=self.scaling,
            attention_mask=attention_mask,
        )

        # 6. Final Reshape and Projection
        # PyTorch: attn_output.reshape(*input_shape, -1).contiguous()

        attn_output = view(attn_output, [0, 0, -1])
        attn_output = self.o_proj(attn_output)

        return attn_output


class LlamaDecoderLayer(Module):

    def __init__(self, config: HydraConfig, layer_idx: int, mapping=Mapping()):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LlamaAttention(config=config,
                                        layer_idx=layer_idx,
                                        mapping=mapping)

        self.mlp = LlamaMLP(config, mapping=mapping)
        self.input_layernorm = RmsNorm(
            normalized_shape=config.hidden_size,
            dtype=config.dtype,
        )
        self.post_attention_layernorm = RmsNorm(
            normalized_shape=config.hidden_size,
            dtype=config.dtype,
        )

    def forward(
            self,
            hidden_states,
            attention_mask=None,
            position_ids=None,
            past_key_value=None,
            output_attentions: Optional[bool] = False,
            use_cache: Optional[bool] = False,
            cache_position=None,
            position_embeddings=None,  # necessary, but kept here for BC
    ):
        residual = hidden_states
        normed_hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        attn_output = self.self_attn(
            hidden_states=normed_hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
        )
        hidden_states = add(residual, attn_output)

        # Fully Connected
        residual = hidden_states
        normed_hidden_states = self.post_attention_layernorm(hidden_states)

        mlp_output = self.mlp(normed_hidden_states)
        hidden_states = add(residual, mlp_output)

        return hidden_states


# refer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class PrefixEmbeddingLayer(Module):

    def __init__(self, config: HydraConfig, mapping=Mapping()):
        super().__init__()
        self.vocab_size = config.vocab_size

        self.layer = LlamaDecoderLayer(config=config,
                                       layer_idx=0,
                                       mapping=mapping)

        self.norm = RmsNorm(
            normalized_shape=config.hidden_size,
            dtype=config.dtype,
        )
        self.rotary_emb = LlamaRotaryEmbedding(config=config, mapping=mapping)

    def forward(
        self,
        inputs_embeds,
        position_ids,
        attention_mask=None,
    ):
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        hidden_states = self.layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=False,
            position_embeddings=position_embeddings,
        )

        hidden_states = self.norm(hidden_states)

        return hidden_states


# HydraForCausalLM is a thin wrapper that picks parent class for GenericHydraForCausalLM.
# All hydra functionality is defined in GenericHydraForCausalLM.
class HydraForCausalLM(PretrainedModel):
    config_class = HydraConfig

    def __init__(self, config: HydraConfig):
        super().__init__(config)

        BaseLM = QWenForCausalLM if hasattr(
            config,
            "model_type") and "qwen" in config.model_type else LLaMAForCausalLM

        class GenericHydraForCausalLM(BaseLM):

            def __init__(self, config: HydraConfig):
                super().__init__(config)
                self.num_hydra_heads = config.num_hydra_heads
                self.num_hydra_layers = config.num_hydra_layers
                self.hidden_size = config.hidden_size
                self.vocab_size = config.vocab_size
                vocab_size_padded = pad_vocab_size(self.vocab_size,
                                                   config.mapping.tp_size)

                base_kwargs = config.to_dict()
                prefix_config = BaseLM.config_class(**base_kwargs)
                self.prefix_embedding_layer = PrefixEmbeddingLayer(
                    prefix_config)

                self.hydra_heads = ModuleList([
                    HydraPrefixMLP(num_layers=self.num_hydra_layers - 1,
                                   hidden_size=config.hidden_size,
                                   vocab_size=vocab_size_padded,
                                   hydra_head_idx=i,
                                   hidden_act=config.hidden_act,
                                   dtype=config.dtype,
                                   mapping=config.mapping)
                    for i in range(self.num_hydra_heads)
                ])

                self.input_embed_fn = self.transformer.vocab_embedding
                self.max_hydra_token_len = config.max_draft_len

            def forward(self, *args, **kwargs):
                output_original = True
                hidden_states = super().forward(*args, **kwargs)

                if kwargs['use_cache']:
                    if default_net().plugin_config.paged_kv_cache:
                        lm_logits, hidden_states, _ = hidden_states
                    else:
                        lm_logits, presents, hidden_states = hidden_states

                if self.mapping.is_last_pp_rank():
                    position_ids = kwargs["position_ids"]

                    hidden_states_3d = unsqueeze(
                        hidden_states, 1)  # Shape: [B, H] -> [B, 1, H]

                    prefix_embedding = self.prefix_embedding_layer(
                        inputs_embeds=hidden_states_3d,
                        position_ids=position_ids,
                        attention_mask=None,
                    )

                    _, topk_ids = topk(lm_logits, k=1, dim=-1)
                    next_embedding = self.input_embed_fn(squeeze(topk_ids, -1))

                    # TODO: Need to convert back into for-loop
                    # prefix_embedding and next_embedding are 2D: [batch, hidden_size]
                    # prefix_embedding_3d = unsqueeze(prefix_embedding, 1) # -> [batch, 1, hidden_size]
                    next_embedding_3d = unsqueeze(
                        next_embedding, 1)  # -> [batch, 1, hidden_size]

                    all_head_logits = []

                    # --- Head 0 ---
                    head_0_input = concat([prefix_embedding, next_embedding_3d],
                                          dim=2)
                    # head_0_input = concat([next_embedding_3d, next_embedding_3d], dim=2)
                    head_0_logits = self.hydra_heads[0](head_0_input)
                    all_head_logits.append(squeeze(head_0_logits, dim=1))

                    # --- Head 1 ---
                    _, next_token_ids_1 = topk(head_0_logits, k=1, dim=-1)
                    next_embedding_1 = self.input_embed_fn(
                        squeeze(next_token_ids_1, -1))

                    head_1_input = concat([head_0_input, next_embedding_1],
                                          dim=2)
                    head_1_logits = self.hydra_heads[1](head_1_input)
                    all_head_logits.append(squeeze(head_1_logits, dim=1))

                    # --- Head 2 ---
                    _, next_token_ids_2 = topk(head_1_logits, k=1, dim=-1)
                    next_embedding_2 = self.input_embed_fn(
                        squeeze(next_token_ids_2, -1))

                    head_2_input = concat([head_1_input, next_embedding_2],
                                          dim=2)
                    head_2_logits = self.hydra_heads[2](head_2_input)
                    all_head_logits.append(squeeze(head_2_logits, dim=1))

                    # --- Head 3 ---
                    _, next_token_ids_3 = topk(head_2_logits, k=1, dim=-1)
                    next_embedding_3 = self.input_embed_fn(
                        squeeze(next_token_ids_3, -1))

                    head_3_input = concat([head_2_input, next_embedding_3],
                                          dim=2)
                    head_3_logits = self.hydra_heads[3](head_3_input)
                    all_head_logits.append(squeeze(head_3_logits, dim=1))

                    medusa_logits = stack(all_head_logits, dim=0)
                    medusa_logits.mark_output('medusa_logits',
                                              self.config.logits_dtype)

                else:
                    hidden_states.mark_output('hidden_states_output',
                                              self.config.dtype)

                if kwargs['use_cache'] and not default_net(
                ).plugin_config.paged_kv_cache:
                    if self.mapping.is_last_pp_rank():
                        if output_original:
                            return (medusa_logits, lm_logits, presents)
                        return (medusa_logits, presents)
                    return (hidden_states, presents)
                else:
                    if self.mapping.is_last_pp_rank():
                        if output_original:
                            return medusa_logits, lm_logits
                        return medusa_logits
                    return hidden_states

            def prepare_inputs(self, *args, **kwargs):
                kwargs['speculative_decoding_draft_tokens_external'] = False
                kwargs['max_draft_len'] = self.max_hydra_token_len
                return super().prepare_inputs(*args, **kwargs)

        self.model = GenericHydraForCausalLM(config)

    # Specialization to redirect accesses to self.model
    def __getattribute__(self, name):
        if name == 'model' or '__' in name:
            return object.__getattribute__(self, name)
        else:
            model = object.__getattribute__(self, 'model')
            return model.__getattribute__(name)

    # Override specialized __setattr__ defined in Module
    def __setattr__(self, name, value) -> None:
        object.__setattr__(self, name, value)

    @classmethod
    def from_hugging_face(
            cls,
            hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
            dtype: str = 'auto',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            **kwargs):
        import transformers

        assert hf_model_or_dir is not None
        speculative_model_dir = kwargs.get('speculative_model', None)

        use_preloading = isinstance(hf_model_or_dir,
                                    transformers.PreTrainedModel)
        if use_preloading:
            hf_model = hf_model_or_dir
            hf_config_or_dir = hf_model.config
        else:
            hf_model_dir = hf_model_or_dir
            hf_config_or_dir = hf_model_or_dir

        config = HydraConfig.from_hugging_face(hf_config_or_dir,
                                               dtype=dtype,
                                               mapping=mapping,
                                               quant_config=quant_config,
                                               **kwargs)

        # ModelOpt ckpt has combined base model and Hydra-head
        is_modelopt_ckpt = not speculative_model_dir

        if not use_preloading:
            trust_remote_code = kwargs.pop('trust_remote_code', True)

            if is_modelopt_ckpt:
                hf_model = LLaMAForCausalLM.from_hugging_face(
                    hf_model_dir,
                    dtype,
                    mapping=mapping,
                    quant_config=quant_config,
                    **kwargs)
            else:
                hf_model = AutoModelForCausalLM.from_pretrained(
                    hf_model_dir,
                    torch_dtype="auto",
                    trust_remote_code=trust_remote_code)

                assert isinstance(hf_model, transformers.PreTrainedModel)

        if is_modelopt_ckpt:
            weights = {
                name: numpy_to_torch(param.raw_value)
                for name, param in hf_model.named_parameters()
            }
        else:
            weights = convert_hf_llama(
                hf_model,
                config.mapping,
                dtype='float16',
                use_parallel_embedding=config.use_parallel_embedding)

        model = cls(config)

        if is_modelopt_ckpt:
            num_hydra_heads = config.config.num_hydra_heads
            num_hydra_layers = config.config.num_hydra_layers
            speculative_model_dir = hf_model_or_dir
        else:
            config_file = speculative_model_dir / "config.json"
            with open(config_file) as fp:
                model_config = json.load(fp)

            num_hydra_heads = kwargs[
                'speculative_config'].num_hydra_heads if 'speculative_config' in kwargs else model_config.get(
                    'hydra_num_heads', None)
            num_hydra_layers = model_config.get('hydra_num_layers', None)
        hydra_weights = load_hydra_hf(hydra_path=speculative_model_dir,
                                      num_hydra_heads=num_hydra_heads,
                                      num_hydra_layers=num_hydra_layers,
                                      mapping=mapping,
                                      dtype="float16",
                                      base_config=hf_model.config,
                                      is_modelopt_ckpt=is_modelopt_ckpt)
        weights.update(hydra_weights)
        model.load(weights)
        return model


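As you can see, most of this is actually the model definition for PrefixEmbeddingLayer, but thankfully my own model definition finally no longer triggers CUDA illegal memory access errors…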
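The four unrolled "Head 0 … Head 3" blocks in forward() repeat one pattern, which the TODO comment wants to turn back into a loop. Below is a minimal pure-Python sketch of that pattern, assuming toy stand-ins: vectors are plain lists, `toy_embed` plays the role of `input_embed_fn`, and each head is a hypothetical callable from its input vector to logits. It is not TensorRT-LLM graph code (no `topk`/`concat` tensors); it only shows that head i's input grows by one token embedding per step, reaching a width of (i + 2) * hidden_size, and that each head must wait for the previous head's token.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def argmax(values: Sequence[float]) -> int:
    """Greedy top-1 index, playing the role of topk(..., k=1)."""
    return max(range(len(values)), key=lambda i: values[i])


def hydra_draft(
    prefix_embedding: Vector,
    base_logits: Sequence[float],
    embed: Callable[[int], Vector],
    heads: Sequence[Callable[[Vector], Sequence[float]]],
) -> List[int]:
    """Rolled-up form of the unrolled Head 0..3 blocks (toy sketch).

    Head 0 sees prefix_embedding ++ embed(base model's top-1 token).
    Head i (i >= 1) sees head i-1's input ++ embed(head i-1's top-1 token),
    so its input width is (i + 2) * hidden_size.
    """
    # Seed the chain with the base model's greedy token (lm_logits -> topk).
    head_input = prefix_embedding + embed(argmax(base_logits))
    draft_tokens: List[int] = []
    for i, head in enumerate(heads):
        token = argmax(head(head_input))
        draft_tokens.append(token)
        if i + 1 < len(heads):
            # Head i+1 cannot run before head i has produced its token.
            head_input = head_input + embed(token)
    return draft_tokens


VOCAB = 3
seen_widths: List[int] = []


def toy_embed(token: int) -> Vector:
    # Toy embedding table with hidden_size = 2
    return [float(token), 1.0]


def make_head(favorite: int) -> Callable[[Vector], List[float]]:
    # Toy head: records its input width and always prefers `favorite`
    def head(x: Vector) -> List[float]:
        seen_widths.append(len(x))
        return [1.0 if v == favorite else 0.0 for v in range(VOCAB)]
    return head


heads = [make_head(f) for f in (2, 0, 1, 2)]
print(hydra_draft([0.5, 0.5], [0.1, 0.9, 0.0], toy_embed, heads))  # [2, 0, 1, 2]
print(seen_widths)  # [4, 6, 8, 10]
```

The widths 4, 6, 8, 10 match (i + 2) * hidden_size for hidden_size = 2, which is the same growth as `head_0_input` through `head_3_input` in the graph code.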

I'm also including the source code I use to load the Hydra head weights:

def load_hydra_hf(hydra_path: str,
                  num_hydra_heads: int,
                  num_hydra_layers: int,
                  base_config: PretrainedConfig,
                  mapping=Mapping(),
                  dtype='float32',
                  use_weight_only=False,
                  plugin_weight_only_quant_type=None,
                  is_modelopt_ckpt=False):
    if is_modelopt_ckpt:
        from safetensors.torch import load_file
        state_dict = {}
        for filename in sorted(Path(hydra_path).glob("*.safetensors")):
            print(f"Loading the weights of Hydra heads from {filename}")
            state_dict.update(load_file(filename))
    else:
        is_ckpt_safetensors = False

        ckpt_file = Path(hydra_path) / "hydra_lm_head.pt"
        if not ckpt_file.exists():
            ckpt_file = Path(hydra_path) / "hydra_lm_head.safetensors"
            is_ckpt_safetensors = True

        if is_ckpt_safetensors:
            logger.info("Safetensors found ...")
            from safetensors.torch import load_file
            state_dict = load_file(ckpt_file)
        else:
            state_dict = torch.load(ckpt_file, map_location="cpu")

    torch_dtype = str_dtype_to_torch(dtype)
    weights = {}

    # Embedding
    # embedding_weight = state_dict["prefix_embeding_layer.embed_tokens.weight"].clone().to(torch_dtype)
    # split_emb = split(embedding_weight, mapping.tp_size, mapping.tp_rank, dim=0)
    # weights["prefix_embedding_layer.layer.vocab_embedding.weight"] = split_emb

    # Attention (QKV, O)
    q_w = state_dict[
        f"prefix_embeding_layer.layers.0.self_attn.q_proj.weight"].clone().to(
            torch_dtype)
    k_w = state_dict[
        f"prefix_embeding_layer.layers.0.self_attn.k_proj.weight"].clone().to(
            torch_dtype)
    v_w = state_dict[
        f"prefix_embeding_layer.layers.0.self_attn.v_proj.weight"].clone().to(
            torch_dtype)
    o_w = state_dict[
        f"prefix_embeding_layer.layers.0.self_attn.o_proj.weight"].clone().to(
            torch_dtype)

    weights[
        f"prefix_embedding_layer.layer.self_attn.q_proj.weight"] = split_matrix_tp(
            q_w,
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.self_attn.k_proj.weight"] = split_matrix_tp(
            k_w,
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.self_attn.v_proj.weight"] = split_matrix_tp(
            v_w,
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.self_attn.o_proj.weight"] = split_matrix_tp(
            o_w,
            mapping.tp_size,
            mapping.tp_rank,
            dim=1,
        )

    # MLP (fc, gate, proj)
    weights[
        f"prefix_embedding_layer.layer.mlp.gate_proj.weight"] = split_matrix_tp(
            state_dict[f"prefix_embeding_layer.layers.0.mlp.gate_proj.weight"].
            clone().to(torch_dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.mlp.up_proj.weight"] = split_matrix_tp(
            state_dict[f"prefix_embeding_layer.layers.0.mlp.up_proj.weight"].
            clone().to(torch_dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        )
    weights[
        f"prefix_embedding_layer.layer.mlp.down_proj.weight"] = split_matrix_tp(
            state_dict[f"prefix_embeding_layer.layers.0.mlp.down_proj.weight"].
            clone().to(torch_dtype),
            mapping.tp_size,
            mapping.tp_rank,
            dim=1,
        )

    # LayerNorm (no need to split)
    weights[
        f"prefix_embedding_layer.layer.input_layernorm.weight"] = state_dict[
            f"prefix_embeding_layer.layers.0.input_layernorm.weight"].clone(
            ).to(torch_dtype)
    weights[f"prefix_embedding_layer.layer.post_attention_layernorm.weight"] = state_dict[
        f"prefix_embeding_layer.layers.0.post_attention_layernorm.weight"].clone(
        ).to(torch_dtype)
    weights[f"prefix_embedding_layer.norm.weight"] = state_dict[
        f"prefix_embeding_layer.norm.weight"].clone().to(torch_dtype)

    # Load Hydra heads weights
    for i in range(num_hydra_heads):
        w = state_dict[f"hydra_mlp.{i}.1.linear.weight"].clone().to(torch_dtype)
        weights[f"hydra_heads.{i}.hydra_mlp.linear.weight"] = split(
            w, mapping.tp_size, mapping.tp_rank, dim=0)
        # Biases of dim-0 (output-dim) splits are sharded the same way
        weights[f"hydra_heads.{i}.hydra_mlp.linear.bias"] = split(
            state_dict[f"hydra_mlp.{i}.1.linear.bias"].clone().to(torch_dtype),
            mapping.tp_size, mapping.tp_rank, dim=0)

        # res_connection weights
        w_res = state_dict[f"hydra_mlp.{i}.1.res_connection.weight"].clone().to(
            torch_dtype)
        weights[f"hydra_heads.{i}.hydra_mlp.res_connection.weight"] = split(
            w_res, mapping.tp_size, mapping.tp_rank, dim=0)
        weights[f"hydra_heads.{i}.hydra_mlp.res_connection.bias"] = split(
            state_dict[f"hydra_mlp.{i}.1.res_connection.bias"].clone().to(
                torch_dtype), mapping.tp_size, mapping.tp_rank, dim=0)

        for l_idx in range(num_hydra_layers - 1):
            seq_idx = 3 + 2 * l_idx  # 3, 5, 7, 9...
            w = state_dict[f"hydra_mlp.{i}.{seq_idx}.linear.weight"].clone().to(
                torch_dtype)
            weights[
                f"hydra_heads.{i}.hydra_mlps.{l_idx}.linear.weight"] = split(
                    w, mapping.tp_size, mapping.tp_rank, dim=0)
            weights[
                f"hydra_heads.{i}.hydra_mlps.{l_idx}.linear.bias"] = split(
                    state_dict[f"hydra_mlp.{i}.{seq_idx}.linear.bias"].clone().to(
                        torch_dtype), mapping.tp_size, mapping.tp_rank, dim=0)

        # Load lm_head
        w_lm = state_dict[f"hydra_lm_head.{i}.1.weight"].clone().to(torch_dtype)
        weights[f"hydra_heads.{i}.hydra_lm_head.weight"] = split(
            w_lm, mapping.tp_size, mapping.tp_rank, dim=0)

        if f"hydra_lm_head.{i}.1.bias" in state_dict:
            weights[f"hydra_heads.{i}.hydra_lm_head.bias"] = split(
                state_dict[f"hydra_lm_head.{i}.1.bias"].clone().to(torch_dtype),
                mapping.tp_size, mapping.tp_rank, dim=0)

    return weights
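Most of `load_hydra_hf` is a renaming exercise: Hydra checkpoint keys (note the original repository's spelling `prefix_embeding_layer`) become the weight names of the TensorRT-LLM modules, and within each `hydra_mlp.{i}` the weight-carrying layers after the first appear at indices 3, 5, 7, …, hence `seq_idx = 3 + 2 * l_idx`. The sketch below reproduces only that name mapping, plus a toy dim-0 tensor-parallel shard on nested lists, assuming the key patterns used in the loader above and an even split. The helpers `remap_hydra_key` and `split_rows` are hypothetical illustrations, not part of TensorRT-LLM; the real loader also shards `o_proj`/`down_proj` along dim 1, which this toy does not cover.

```python
import re
from typing import List, Optional


def remap_hydra_key(key: str) -> Optional[str]:
    """Map a Hydra checkpoint key to the weight name used by load_hydra_hf.

    Returns None for keys this sketch does not cover.
    """
    # The single decoder layer of the prefix embedding layer ("layers.0" -> "layer").
    m = re.fullmatch(r"prefix_embeding_layer\.layers\.0\.(.+)", key)
    if m:
        return f"prefix_embedding_layer.layer.{m.group(1)}"
    if key == "prefix_embeding_layer.norm.weight":
        return "prefix_embedding_layer.norm.weight"
    # Index 1 of each head: the first block, which also has a res_connection.
    m = re.fullmatch(
        r"hydra_mlp\.(\d+)\.1\.(linear|res_connection)\.(weight|bias)", key)
    if m:
        head, name, kind = m.groups()
        return f"hydra_heads.{head}.hydra_mlp.{name}.{kind}"
    # Later blocks sit at seq_idx = 3, 5, 7, ... -> l_idx = (seq_idx - 3) // 2.
    m = re.fullmatch(r"hydra_mlp\.(\d+)\.(\d+)\.linear\.(weight|bias)", key)
    if m:
        head, seq_idx, kind = m.group(1), int(m.group(2)), m.group(3)
        if seq_idx >= 3 and seq_idx % 2 == 1:
            return f"hydra_heads.{head}.hydra_mlps.{(seq_idx - 3) // 2}.linear.{kind}"
        return None
    m = re.fullmatch(r"hydra_lm_head\.(\d+)\.1\.(weight|bias)", key)
    if m:
        return f"hydra_heads.{m.group(1)}.hydra_lm_head.{m.group(2)}"
    return None


def split_rows(matrix: List[List[float]], tp_size: int,
               tp_rank: int) -> List[List[float]]:
    """Toy dim-0 shard: rank r keeps the r-th contiguous block of rows."""
    if len(matrix) % tp_size != 0:
        raise ValueError("number of rows must be divisible by tp_size")
    rows = len(matrix) // tp_size
    return matrix[tp_rank * rows:(tp_rank + 1) * rows]


print(remap_hydra_key("hydra_mlp.2.5.linear.weight"))
# hydra_heads.2.hydra_mlps.1.linear.weight
print(remap_hydra_key("prefix_embeding_layer.layers.0.mlp.up_proj.weight"))
# prefix_embedding_layer.layer.mlp.up_proj.weight
print(split_rows([[1.0], [2.0], [3.0], [4.0]], tp_size=2, tp_rank=1))
# [[3.0], [4.0]]
```

A 1-D bias that belongs to a dim-0-sharded linear would be chunked the same way as the weight's rows, so each rank's bias length matches its slice of the output dimension.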

Test Results

First, my machine configuration is as follows:

  • 1x RTX 4090 (24GB VRAM)
  • 64GB RAM
  • CPU: Intel Core i7 (13th Gen)

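Every configuration was tested with the same prompt, “Who are you?”, at batch_size=1, running 10 consecutive times and taking the average (the --run_profiling flag in run.py enables this automated benchmarking).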
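The measurement procedure above (same prompt, batch_size=1, 10 consecutive runs, averaged) can be sketched as a small timing harness. This is only an assumption-laden illustration, not what `run.py --run_profiling` does internally: `generate` is a hypothetical callable that performs one full generation and returns the number of generated tokens, and the `warmup` parameter (default 0, matching the described setup) is an extra option not mentioned in the original.

```python
import statistics
import time
from typing import Callable, Dict


def profile_generation(generate: Callable[[str], int],
                       prompt: str = "Who are you?",
                       runs: int = 10,
                       warmup: int = 0) -> Dict[str, float]:
    """Time `runs` consecutive batch-size-1 generations and average them.

    `generate` is a hypothetical callable: it runs one full generation for
    `prompt` and returns the number of tokens it produced.
    """
    for _ in range(warmup):
        generate(prompt)  # optional untimed runs (not part of the original setup)
    latencies, token_counts = [], []
    for _ in range(runs):
        start = time.perf_counter()
        token_counts.append(generate(prompt))
        latencies.append(time.perf_counter() - start)
    total_time = sum(latencies)
    return {
        "mean_latency_s": statistics.mean(latencies),
        "stdev_latency_s": statistics.stdev(latencies) if runs > 1 else 0.0,
        "tokens_per_s": (sum(token_counts) / total_time
                         if total_time > 0 else float("inf")),
    }


def fake_generate(prompt: str) -> int:
    time.sleep(0.01)  # pretend one generation takes about 10 ms
    return 32


stats = profile_generation(fake_generate, runs=3)
print(stats)
```

Reporting tokens per second alongside mean latency is useful for speculative decoding, because the number of generated tokens per run can differ between methods even for the same prompt.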

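This is a very rough benchmark, but it shows that although my implementation does not match Medusa, it is clearly faster than vanilla Vicuna-7B!

Of course, the Hydra paper reports better results than Medusa, so I think there are two possible explanations: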

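  1. Hydra Heads have sequential dependencies: unlike Medusa, multiple heads cannot decode from a single hidden state all at once, which may be unfriendly to highly parallel GPU compute graphs.
  2. I think it's more likely that my implementation just isn't very good XD. After all, Medusa was optimized by the official TensorRT-LLM developers, whereas I pieced mine together until it ran. I even implemented TensorRT-LLM support for a LlamaModel from scratch just to avoid illegal memory accesses, and it's hard to say how many optimizations I sacrificed along the way.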
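Point 1 can be made concrete with a toy model in which every head is a single linear map (a simplification: real Medusa and Hydra heads also contain residual MLP blocks). Because every Medusa head reads the same hidden state, K heads can be fused into one stacked matrix and evaluated with a single matrix-vector product; a Hydra head's input also contains the embedding of the previous head's argmax, so its product cannot start until the previous one finishes. The helper names below are illustrative only and do not describe how TensorRT-LLM actually schedules either method.

```python
from typing import List

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """Dense matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def medusa_logits_fused(heads: List[Matrix], hidden: List[float]) -> List[List[float]]:
    """Evaluate K single-linear Medusa heads with ONE stacked matvec.

    This works only because every head reads the same `hidden`. A Hydra head's
    input also contains the embedding of the previous head's argmax, so the
    heads cannot be stacked like this; they must run one after another.
    """
    vocab = len(heads[0])
    stacked = [row for w in heads for row in w]  # shape [K * vocab, hidden]
    flat = matvec(stacked, hidden)
    return [flat[k * vocab:(k + 1) * vocab] for k in range(len(heads))]


hidden = [1.0, -1.0]
heads = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 3.0]]]  # K = 2, vocab = 2
print(medusa_logits_fused(heads, hidden))  # [[1.0, -1.0], [2.0, -3.0]]
print([matvec(w, hidden) for w in heads])  # same result, K separate calls
```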

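Conclusion

Adding Hydra support to TensorRT-LLM took roughly two weekends plus all of my spare time for an entire week, about 30 to 35 hours in total, making it the most time-consuming single-feature side project I've done so far.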

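Still, it felt great to personally hack an inference acceleration framework and implement support for a new speculative decoding method. Although many modules could follow Medusa's existing implementation directly, there were also quite a few problems that Medusa never runs into.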

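This write-up can hardly capture every detail of the implementation process. I debugged far too many times along the way, and I lost count of how often I nearly gave up and moved on to the next project XD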

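But at least I'm now much more familiar with how TensorRT-LLM supports tree-based speculative decoding. Next, I might try supporting other speculative decoding methods, or further optimize the Hydra support; at the very least, I think the C++ backend and CUDA kernel parts could still be improved.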

If you're interested, feel free to check out my implementation: https://github.com/ccs96307/TensorRT-LLM/tree/support-spec-decode-hydra/examples/hydra

I hope to add documentation soon! Thanks~

