Skip to content

使用 HuggingFace `transformers` 套件中模型的 `assistant_model` 方法來進行 Speculative Decoding 的加速

Last Updated on 2024-11-18 by Clay

最近嘗試實作了許多推測性解碼(Speculative Decoding)的加速方法,而 HuggingFace 的 transformers 套件中自然也有對應的加速方法 assistant_model,今天就趁這個機會一起紀錄下來。

不過需要注意的是,在要使用這些方法前,建議先開個 Python 虛擬環境並把 transformers 升級到最新版本。


`assistant_model` 使用方法

如果是想要了解 Speculative Decoding 的原理,可以參考原始論文:Fast Inference from Transformers via Speculative Decoding

或是我的筆記:[論文閱讀] Fast Inference from Transformers via Speculative Decoding

而若是想要在 transformers 中使用 Speculative Decoding 技術也非常地簡單,我們可以在模型使用 .generate() 方法進行解碼時,透過 assistant_model 這個參數傳遞 draft model 進去進行加速解碼。

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def main() -> None:
    # Settings
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    target_model_path = "./models/HuggingFaceTB--SmolLM2-1.7B-Instruct"
    draft_model_path = "./models/HuggingFaceTB--SmolLM2-135M-Instruct"

    # Load Tokenizer
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)

    # Load Model
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenizer
    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]


    # Tokenize
    input_text=target_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Target Model Generate Directly
    start_time = time.time()
    outputs = target_model.generate(**inputs, max_new_tokens=100)

    generated_token_num = outputs.shape[-1] - inputs["input_ids"].shape[-1]

    print("=== Directly Generate ===")
    print(f"Generated Tokens: {generated_token_num}")
    print(f"Spent Time: {time.time() - start_time} seconds.\n")

    # Speculative Decoding
    print("=== Speculative Decoding ===")
    start_time = time.time()
    outputs = target_model.generate(
        **inputs,
        max_new_tokens=100,
        assistant_model=draft_model,
    )

    generated_token_num = outputs.shape[-1] - inputs["input_ids"].shape[-1]

    print(f"Generated Tokens: {generated_token_num}")
    print(f"Spent Time: {time.time() - start_time} seconds.")


if __name__ == "__main__":
    main()


Output:

=== Directly Generate ===
Generated Tokens: 100
Spent Time: 1.9954736232757568 seconds.

=== Speculative Decoding ===
Generated Tokens: 100
Spent Time: 1.9073119163513184 seconds.

不過,由於我測試的量級大小了,比較測試不出加速的提昇。


現在已經支援不使用同樣的 Tokenizer 一樣可以做 Speculative Decoding

另外,以前我們都會說,要讓 draft model 能夠加強 target model,兩者需要享有同樣的詞彙表,換言之也就是需要使用相同的 tokenizer —— 這是因為本來的採樣方法進行驗證時,會需要針對同樣位置的 token 進行機率分佈的驗證。

不過現在 HuggingFace 其實是支援不同詞彙表的 draft model 來加速解碼的哦!詳情可以查看 Universal Assisted Generation: Faster Decoding with Any Assistant Model 這篇文章。

概念其實非常簡單:我們可以先把 draft model 解碼的結果 decode 回字串、再使用 target tokenizer 解碼回 Tokens,就可以讓 target model 進行驗證了!不過當然這樣的話我們目前只能使用 token 做 greedy search 的驗證,而沒辦法基於論文中提出來的 Speculative Sampling 進行驗證。

不過這或許真的是一個非常有用的技術,畢竟有許多需要被加速的大模型,他們未必都有一個享有相同詞彙表的小尺寸版本。

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def main() -> None:
    # Settings
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    target_model_path = "./models/HuggingFaceTB--SmolLM2-1.7B-Instruct"
    draft_model_path = "./models/openai-community--gpt2"

    # Load Tokenizer
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    print(f"draft tokenizer vocab size: {len(draft_tokenizer)}")
    print(f"target tokenizer vocab size: {len(target_tokenizer)}\n")

    # Load Model
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenizer
    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]


    # Tokenize
    input_text=target_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Target Model Generate Directly
    start_time = time.time()
    outputs = target_model.generate(**inputs, max_new_tokens=100)

    generated_token_num = outputs.shape[-1] - inputs["input_ids"].shape[-1]

    print("=== Directly Generate ===")
    print(f"Generated Tokens: {generated_token_num}")
    print(f"Spent Time: {time.time() - start_time} seconds.\n")

    # Speculative Decoding
    print("=== Speculative Decoding ===")
    start_time = time.time()
    outputs = target_model.generate(
        **inputs,
        max_new_tokens=100,
        assistant_model=draft_model,
        tokenizer=target_tokenizer,
        assistant_tokenizer=draft_tokenizer,
    )

    generated_token_num = outputs.shape[-1] - inputs["input_ids"].shape[-1]

    print(f"Generated Tokens: {generated_token_num}")
    print(f"Spent Time: {time.time() - start_time} seconds.")


if __name__ == "__main__":
    main()


Output:

draft tokenizer vocab size: 50257
target tokenizer vocab size: 49152

=== Directly Generate ===
Generated Tokens: 100
Spent Time: 2.0288448333740234 seconds.

=== Speculative Decoding ===
Generated Tokens: 100
Spent Time: 2.306903839111328 seconds.

可以看到我們確實可以使用兩個擁有不同詞彙表的 draft model 和 target model,並能用於 Speculative Decoding,但是我們就需要額外傳入兩者的 Tokenizers 了。畢竟,現在在生成驗證階段會用到兩者的 Tokenizers 進行編碼與再次解碼。


References


Read More

Leave a Reply取消回覆

Exit mobile version