
Accelerating vLLM with Arctic Inference and Custom Speculators

Last Updated on 2025-05-10 by Clay

Generation Latency: A Bottleneck in LLM Applications

Today, I’d like to share an exciting research project and open-source framework developed by Snowflake: Arctic Inference.

Generation latency remains the primary bottleneck in many LLM-powered applications. While Speculative Decoding is a promising approach to mitigate this issue, current open-source implementations still have several limitations:

  • During self-reflection loops and multi-hop reasoning, speculative models typically draft only a few tokens ahead, even though these sequences are highly repetitive and could support much longer drafts.
  • There's no standardized training framework for building custom draft models.
  • System-level inefficiencies (e.g., communication overhead) prevent draft models from reaching their theoretical speedups.

Arctic Inference: Engineering Optimizations

Arctic Inference addresses these shortcomings with a number of optimizations:

  • FP8 Quantization: Running in FP8 precision reduces memory consumption and improves draft-model latency.
  • Tensor Parallelism (TP): Distributes computation across multiple GPUs. (Note: for single-GPU setups like laptops, TP won’t provide benefits.)
  • Communication Optimization: Reduces inter-GPU communication by reordering the logit aggregation pipeline (sketched after this list):
    • Initial: Logits(Sharded) -> AllGather -> Logits(Global) -> TopK(Global)
    • Optimized: Logits(Sharded) -> TopK(Sharded) -> AllGather -> TopK(Global)
  • CUDA Graphs: Captures the entire speculative decoding loop as a CUDA graph, reducing kernel launch overhead and improving throughput.
  • Greedy Decoding Match: Uses greedy decoding for verification instead of rejection sampling, improving verification speed.
  • Suffix Decoding Integration: See below.
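To make the reordered pipeline concrete, here is a minimal PyTorch sketch of the idea (the names, shapes, and packing scheme are my own assumptions, not the actual Arctic Inference kernels): each rank takes a local top-k over its logit shard, all-gathers only k (value, index) pairs per rank, and then reduces to a global top-k.

# Minimal sketch of the communication optimization (assumed shapes and names).
import torch
import torch.distributed as dist

def sharded_topk(logits_shard: torch.Tensor, k: int, vocab_offset: int):
    """logits_shard: [batch, vocab_shard] slice owned by this rank."""
    values, local_idx = torch.topk(logits_shard, k, dim=-1)      # TopK(Sharded)
    global_idx = (local_idx + vocab_offset).to(values.dtype)     # map to full-vocab token ids
    packed = torch.stack([values, global_idx], dim=-1)           # [batch, k, 2]

    gathered = [torch.empty_like(packed) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, packed)                            # AllGather of only k entries per rank

    candidates = torch.cat(gathered, dim=1)                      # [batch, k * world_size, 2]
    top_vals, order = torch.topk(candidates[..., 0], k, dim=-1)  # TopK(Global)
    top_ids = torch.gather(candidates[..., 1], 1, order).long()
    return top_vals, top_ids

The all-gather payload shrinks from a full vocabulary shard to k entries per rank, which is where the communication savings come from.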

Suffix Decoding for Repetitive Generation

Suffix Decoding is designed for highly repetitive sequences and maintains:

  • One suffix tree for the current sequence (prompt + ongoing generation)
  • One suffix tree for previously generated tokens
  • Speculative candidates are ranked by frequency matches within the suffix structure (a toy sketch of the idea follows this list)
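
As a toy illustration of the model-free idea (this is not the actual SuffixDecoding suffix-tree implementation, just a frequency-indexed sketch of the ranking logic):

# Toy illustration of suffix-based, model-free speculation.
from collections import defaultdict

class SuffixSpeculator:
    def __init__(self, max_suffix_len: int = 4):
        self.max_suffix_len = max_suffix_len
        # maps a suffix (tuple of token ids) -> {next_token: count}
        self.continuations = defaultdict(lambda: defaultdict(int))

    def update(self, tokens: list[int]) -> None:
        """Index every suffix of length 1..max_suffix_len together with its next token."""
        for i in range(1, len(tokens)):
            for n in range(1, min(self.max_suffix_len, i) + 1):
                suffix = tuple(tokens[i - n:i])
                self.continuations[suffix][tokens[i]] += 1

    def speculate(self, tokens: list[int], num_draft: int = 3) -> list[int]:
        """Greedily extend the current sequence with the most frequent continuations."""
        draft, context = [], list(tokens)
        for _ in range(num_draft):
            best = None
            for n in range(min(self.max_suffix_len, len(context)), 0, -1):
                counts = self.continuations.get(tuple(context[-n:]))
                if counts:
                    best = max(counts, key=counts.get)   # longest matching suffix wins
                    break
            if best is None:
                break
            draft.append(best)
            context.append(best)
        return draft

spec = SuffixSpeculator()
spec.update([1, 2, 3, 4, 1, 2, 3, 5])           # index a previously generated sequence
print(spec.speculate([9, 1, 2], num_draft=3))   # continues the repeated pattern, e.g. [3, 4, 1]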

Reference: SuffixDecoding: A Model-Free Approach to Speeding Up Large Language Model Inference


Arctic Training: Improving Speculation for Non-Repetitive Generation

Arctic Training supports training of two types of draft models:

  • MLP-based Speculators: Simple feed-forward models
  • LSTM-based Speculators: Utilize full LSTM gating (input, forget, output, cell) to model temporal dependencies; a minimal sketch follows below
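
For intuition, here is a minimal sketch of what an LSTM-style speculator head can look like (the real Arctic Training speculators differ in architecture and training details; this only illustrates the recurrent drafting loop on top of the base model's last hidden state):

# Minimal sketch of an LSTM-style speculator head (assumed architecture).
import torch
import torch.nn as nn

class LSTMSpeculatorHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, num_speculative_tokens: int = 3):
        super().__init__()
        self.num_speculative_tokens = num_speculative_tokens
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.cell = nn.LSTMCell(hidden_size, hidden_size)   # full input/forget/output/cell gating
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, last_hidden: torch.Tensor, last_token: torch.Tensor) -> torch.Tensor:
        """last_hidden: [batch, hidden]; last_token: [batch]. Returns [batch, k, vocab]."""
        h, c = last_hidden, torch.zeros_like(last_hidden)
        x, logits = self.embed(last_token), []
        for _ in range(self.num_speculative_tokens):
            h, c = self.cell(x, (h, c))
            step_logits = self.lm_head(h)
            logits.append(step_logits)
            x = self.embed(step_logits.argmax(dim=-1))       # feed back the greedy draft token
        return torch.stack(logits, dim=1)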

Experiment Results

Suffix Decoding and Arctic Inference show significant speedups across multiple test settings.


Quickstart Guide: How To Use Arctic Inference

Install Arctic Inference:

pip install arctic-inference[vllm]


Use one of the pre-trained speculators:

  • Snowflake/Arctic-LSTM-Speculator-Llama-3.1-8B-Instruct
  • Snowflake/Arctic-LSTM-Speculator-Llama-3.1-70B-Instruct
  • Snowflake/Arctic-LSTM-Speculator-Llama-3.3-70B-Instruct
  • Snowflake/Arctic-LSTM-Speculator-Qwen2.5-32B-Instruct

Run inference with vllm:

vllm serve \
    meta-llama/Llama-3.1-70B-Instruct \
    --quantization "fp8" \
    --tensor-parallel-size 2 \
    --speculative-config '{
        "method": "arctic",
        "model":"Snowflake/Arctic-LSTM-Speculator-Llama-3.1-70B-Instruct",
        "num_speculative_tokens": 3,
        "enable_suffix_decoding": true
    }'

⚠️ If your service hangs, try adding --seed 1 to stabilize worker behavior.
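
Once the server is up, you can query it through the OpenAI-compatible API that vllm serve exposes (port 8000 by default; the model name below matches the served model from the command above, and the openai Python package is assumed to be installed):

# Example client call against the vLLM OpenAI-compatible endpoint started above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-70B-Instruct",
    messages=[{"role": "user", "content": "Summarize speculative decoding in one sentence."}],
    max_tokens=128,
)
print(response.choices[0].message.content)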


Training Custom Speculators: A Case Study with gemma-2-9b-it

Main Question: Can Arctic Inference reproduce the speedups on other models and tasks?

I trained a custom LSTM speculator for a RAG-based customer service chatbot using gemma-2-9b-it.

Step 1: Modify data_gen_script_maker.py

Replace the default dataset logic with your own:

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model_name", required=True)
parser.add_argument("--data_save_folder_name", required=True)
parser.add_argument("--vllm_tensor_parallel", type=int, default=1)
parser.add_argument("--script_save_path", required=True)
parser.add_argument("--total_num_of_scripts", type=int, default=8)
args = parser.parse_args()
print(args)

model_name = tokenizer_name = args.model_name
data_save_folder_name = args.data_save_folder_name
script_save_path = args.script_save_path
os.makedirs(script_save_path, exist_ok=True)
vllm_tensor_parallel = args.vllm_tensor_parallel

## Ultrachat generation
# total_num_of_scripts = args.total_num_of_scripts
# for i in range(total_num_of_scripts):
#     output_dir = f"{data_save_folder_name}/ultrachat"
#     json_save_path = f"{output_dir}/{i}_{total_num_of_scripts}.jsonl"
#     script = f"""
# python speculator/data_generation/vllm_data_generation.py --model={model_name} --tensor_parallel={vllm_tensor_parallel} --tokenizer={tokenizer_name} --cur_split={i} --output_dataset_path={json_save_path} --total_split={total_num_of_scripts}
#     """
#     with open(f"{script_save_path}/{data_save_folder_name}_{i:02}.sh", "w") as f:
#         f.write(script)

# ## Magicoder generation
# for i in range(total_num_of_scripts):
#     output_dir = f"{data_save_folder_name}/magicoder"
#     json_save_path = f"{output_dir}/{i}_{total_num_of_scripts}.jsonl"
#     script = f"""
# python speculator/data_generation/vllm_data_generation.py --hf_dataset magicoder --model={model_name} --tensor_parallel={vllm_tensor_parallel} --tokenizer={tokenizer_name} --cur_split={i} --output_dataset_path={json_save_path} --total_split={total_num_of_scripts}
#     """
#     with open(f"{script_save_path}/{data_save_folder_name}_magic_{i:02}.sh", "w") as f:
#         f.write(script)


## Customer service data generation
total_num_of_scripts = args.total_num_of_scripts
for i in range(total_num_of_scripts):
    output_dir = f"{data_save_folder_name}/customer_data"
    json_save_path = f"{output_dir}/{i}_{total_num_of_scripts}.jsonl"
    script = f"""
python speculator/data_generation/vllm_data_generation.py --hf_dataset customer_service_data --model={model_name} --tensor_parallel={vllm_tensor_parallel} --tokenizer={tokenizer_name} --cur_split={i} --output_dataset_path={json_save_path} --total_split={total_num_of_scripts}
    """
    with open(f"{script_save_path}/{data_save_folder_name}_{i:02}.sh", "w") as f:
        f.write(script)



Step 2: Load the Local Dataset

In vllm_data_generation.py, add:

def load_hf_dataset(dataset):
    if dataset == "ultrachat":
        return load_dataset(
            "HuggingFaceH4/ultrachat_200k",
            split="train_sft",
            num_proc=32,
        )
     

    # ... (other existing dataset branches remain unchanged) ...

    elif dataset == "customer_service_data":
        # Requires `from datasets import load_from_disk` at the top of the file.
        return load_from_disk("local_dataset/reflected_gemma2_dataset_20250505")["train"]

    else:
        print(f"Dataset {dataset} not supported")
        exit(0)

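The load_from_disk call above expects a dataset saved with the Hugging Face datasets library. Here is a hedged sketch of how such a local dataset could be produced (the "messages" layout is an assumption modeled on ultrachat-style data; match whatever fields vllm_data_generation.py actually consumes):

# Sketch: build and save a local DatasetDict so that load_from_disk() can read it.
from datasets import Dataset, DatasetDict

examples = [
    {"messages": [
        {"role": "user", "content": "Where is my order #12345?"},
        {"role": "assistant", "content": "Let me check the shipping status for you..."},
    ]},
]

DatasetDict({"train": Dataset.from_list(examples)}).save_to_disk(
    "local_dataset/reflected_gemma2_dataset_20250505"
)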

Step 3: Add YAML + Launch Script

Create your own gemma2-9b.yaml and corresponding shell script based on llama3.1-8b.sh.


Step 4: Train

Run the script:

bash scripts/gemma2-9b.sh


After training, you’ll get a config.json and pytorch_model.bin for inference with vllm.
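
Before serving, a quick sanity check of the produced artifacts can be done like this (the checkpoint directory is a placeholder for wherever your run saved its outputs):

# Quick sanity check of the trained speculator artifacts.
import json
import torch

ckpt_dir = "path/to/gemma2-9b-speculator"   # placeholder path
with open(f"{ckpt_dir}/config.json") as f:
    print(json.load(f))                     # speculator architecture / metadata

state = torch.load(f"{ckpt_dir}/pytorch_model.bin", map_location="cpu")
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape))        # first few parameter tensors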


Evaluation: My Acceleration Results

Hardware: 2× RTX 4090
Framework: vLLM v0.8.4
Task: 1,500+ real-world QA samples (not seen during training)

Setup                                Tokens/sec
Original vLLM                        112.27
Suffix Decoding                      223.08
LSTM Speculator                      194.62
LSTM Speculator + Suffix Decoding    221.23

(gemma-2-9b-it with the custom speculator, evaluated under different setups.)
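
For reference, here is a simplified sketch of how tokens/sec can be measured against the running server (not the exact harness used for the table above; the endpoint, model name, and prompt list are placeholders):

# Rough tokens/sec measurement over held-out prompts via the OpenAI-compatible API.
import time
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
prompts = ["<your held-out QA samples>"]    # placeholder prompt list

start, generated = time.time(), 0
for prompt in prompts:
    response = client.chat.completions.create(
        model="google/gemma-2-9b-it",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=512,
    )
    generated += response.usage.completion_tokens   # tokens produced by the server
print(f"{generated / (time.time() - start):.2f} tokens/sec")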

Since my chatbot is RAG-based and its responses repeat large spans of the retrieved context, these results are expected. Notably, both Suffix Decoding and the LSTM Speculator clearly outperform the baseline vLLM setup.

In any case, I think this recipe reproduces robustly across different tasks and model architectures; it even has the potential to become a standard step in future model deployments!

However, during testing with vLLM and Arctic Inference, I did encounter occasional CUDA block errors. While the approach is highly promising, it still needs time to mature and stabilize.

