
Optimizing LayerSkip Models with Bayesian Search for an Effective Layer Skipping Strategy

Last Updated on 2024-11-15 by Clay

In self-speculative decoding, the draft model is carved out of part of the target model's own network, so finding a good Layer Skip Strategy is crucial: we need to skip enough layers to achieve a meaningful speedup, while keeping the draft model's predictions good enough that the target model does not reject them too often.
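As a quick refresher on the mechanism everything below depends on: during verification, a drafted token with draft probability p_draft and target probability p_target is accepted with probability min(1, p_target / p_draft). The snippet below is only a minimal sketch of that rule for a single token (the variable and function names are mine, not from the implementation later in this post):

import random

def accept_drafted_token(p_target: float, p_draft: float) -> bool:
    # Standard speculative decoding rule: accept with probability min(1, p_target / p_draft).
    if p_target >= p_draft:
        return True
    return random.random() < (p_target / p_draft)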

Today’s implementation focuses on optimizing my previously implemented LayerSkip model using the Bayesian optimization framework Optuna, to determine which layers to skip.


Background Review

One challenge of writing a series of articles is that I often start in the middle, requiring me to add some background details to provide context.

For a detailed explanation of speculative decoding, you can refer to: [Paper Reading] Fast Inference from Transformers via Speculative Decoding and Speculative Decoding Implementation Note (with Simple Experimental Results)

If you want a simple introduction to Bayes’ Theorem, you can check out: A Note of Bayes' Theorem

For my implementation of the LayerSkip model in Self-Speculative Decoding, refer to: Self-Speculative Decoding Implementation: LayerSkip Transformer

Feel free to visit my GitHub repository, where I will continue updating implementations of inference acceleration techniques: https://github.com/ccs96307/fast-llm-inference


Implementation Description

Optuna is a well-known Python package most often used for hyperparameter optimization in model training. In simple terms, it uses Bayesian optimization to automatically explore a user-defined hyperparameter space, evaluating each candidate combination against a user-defined objective function in search of a good combination.

In layman’s terms, when our true objective cannot be differentiated, and the intermediate processes are a black box, we can define the inputs and observe the outputs. Bayesian optimization helps us identify better combinations — though not necessarily the absolute best, as exhaustive enumeration is infeasible.
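If you have never used Optuna, the toy example below (unrelated to LayerSkip, purely to show the API) illustrates the basic loop: the objective samples hyperparameters through trial.suggest_*, returns a score, and the study keeps calling it while the sampler proposes increasingly promising combinations.

import optuna

def toy_objective(trial: optuna.Trial) -> float:
    # Optuna's sampler decides which values to try next based on earlier trials.
    x = trial.suggest_int("x", -10, 10)
    y = trial.suggest_int("y", -10, 10)
    # Any score computable from the suggested values works; no gradients are required.
    return -((x - 3) ** 2) - ((y + 1) ** 2)

study = optuna.create_study(direction="maximize")
study.optimize(toy_objective, n_trials=30)
print(study.best_params)  # should end up near {'x': 3, 'y': -1}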

Can't we just enumerate all the Layer Skip combinations and pick the best one? Assuming 20 layers, each with an Attention module and an MLP module that can independently be skipped or kept, the number of combinations is 2^40 = 1,099,511,627,776.
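For the 26-layer Gemma-2 2B model used below, the space is even larger, so exhaustive search is clearly out of the question:

num_layers = 26                       # decoder layers in the Gemma-2 2B model used below
num_binary_choices = 2 * num_layers   # one skip/keep decision each for attn and mlp
print(2 ** num_binary_choices)        # 4503599627370496 possible combinations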

Below, I’ve divided the implementation into several functions:

  • calculate_continuous_acceptance(): Calculates how many of the draft model's tokens the target model accepts (only consecutively accepted tokens count).
  • drafter_speculative_decode(): Decoding function for the draft model, which also returns the probability distributions of the consecutively drafted tokens.
  • target_speculative_decode(): Decoding function for the target model, which verifies the drafted tokens against the draft model's probabilities.
  • objective(): The search function for Optuna.

Although the paper optimizes directly for inference speed, I haven't set up GPUs or a proper test set yet. For now, I tested with a single sentence and the 2B Gemma model, defining the optimization objective as the acceptance rate. The real test results will be updated on GitHub and in this post later.


Full Implementation

import optuna
from typing import Dict, Optional, Tuple

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM
from sampling.sampling import sample_next_token


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance


def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # 0 means top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    draft_mode: bool = True
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_model.set_draft_mode(draft_mode)
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    draft_model.set_draft_mode(True)

    return inputs, torch.cat(draft_probs, dim=1)


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # 0 means top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    target_model.set_draft_mode(False)
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Positions where the draft prob exceeds the target prob are candidates for rejection
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Draw random values to decide whether each candidate is accepted or rejected
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


def objective(trial):
    # Define the search space: an independent skip/keep decision for each layer's attn and mlp
    total_layers = 26

    # Decide whether or not to skip each layer's `attn`
    skip_attn_layers = []
    for i in range(total_layers):
        skip = trial.suggest_int(f'skip_attn_layer_{i}', 0, 1)
        if skip == 1:
            skip_attn_layers.append(i)

    # Decide whether or not to skip each layer's `mlp`
    skip_mlp_layers = []
    for i in range(total_layers):
        skip = trial.suggest_int(f'skip_mlp_layer_{i}', 0, 1)
        if skip == 1:
            skip_mlp_layers.append(i)

    # Prune trials that skip nothing at all, since the draft would just be the full target model
    if len(skip_attn_layers) == 0 and len(skip_mlp_layers) == 0:
        raise optuna.TrialPruned()

    skip_layer_ids = {
        "attn": skip_attn_layers,
        "mlp": skip_mlp_layers,
    }

    # Set the skip strategy
    model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]

    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    is_end = False

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs["input_ids"].shape[1]

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = 5
    max_new_tokens = 100

    while not is_end:
        # Draft model
        target_inputs, draft_probs = drafter_speculative_decode(
            draft_model=model,
            draft_tokenizer=tokenizer,
            inputs=inputs,
            gamma=gamma,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens = target_speculative_decode(
            target_model=model,
            target_tokenizer=tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
        )

        total_accept_tokens += accept_tokens
        inputs = outputs

        if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
            break

    # Compute acceptance rate
    accept_rate = total_accept_tokens / total_draft_tokens

    print(f"attn_skip: {skip_attn_layers}, mlp_skip: {skip_mlp_layers}, Accept Rate: {accept_rate}")

    # Assume we want to maximize `accept_rate`
    return accept_rate


if __name__ == "__main__":
    pretrained_model_name_or_path = "../models/google--gemma-2-2b-it/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    model = LayerSkipGemma2ForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)

    # Init
    skip_layer_ids = {
        "attn": [],
        "mlp": [],
    }

    model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    # Create
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=50)

    print("The best params:", study.best_params)
    print("The best accept_rate:", study.best_value)


Output:

[I 2024-11-13 13:26:42,633] Trial 49 finished with value: 0.04390243902439024 and parameters: {'skip_attn_layer_0': 1, 'skip_attn_layer_1': 0, 'skip_attn_layer_2': 1, 'skip_attn_layer_3': 1, 'skip_attn_layer_4': 0, 'skip_attn_layer_5': 0, 'skip_attn_layer_6': 1, 'skip_attn_layer_7': 0, 'skip_attn_layer_8': 1, 'skip_attn_layer_9': 1, 'skip_attn_layer_10': 0, 'skip_attn_layer_11': 0, 'skip_attn_layer_12': 0, 'skip_attn_layer_13': 0, 'skip_attn_layer_14': 1, 'skip_attn_layer_15': 1, 'skip_attn_layer_16': 1, 'skip_attn_layer_17': 0, 'skip_attn_layer_18': 0, 'skip_attn_layer_19': 1, 'skip_attn_layer_20': 1, 'skip_attn_layer_21': 0, 'skip_attn_layer_22': 0, 'skip_attn_layer_23': 0, 'skip_attn_layer_24': 0, 'skip_attn_layer_25': 1, 'skip_mlp_layer_0': 0, 'skip_mlp_layer_1': 0, 'skip_mlp_layer_2': 0, 'skip_mlp_layer_3': 0, 'skip_mlp_layer_4': 1, 'skip_mlp_layer_5': 0, 'skip_mlp_layer_6': 0, 'skip_mlp_layer_7': 1, 'skip_mlp_layer_8': 1, 'skip_mlp_layer_9': 1, 'skip_mlp_layer_10': 1, 'skip_mlp_layer_11': 0, 'skip_mlp_layer_12': 0, 'skip_mlp_layer_13': 0, 'skip_mlp_layer_14': 1, 'skip_mlp_layer_15': 0, 'skip_mlp_layer_16': 1, 'skip_mlp_layer_17': 0, 'skip_mlp_layer_18': 0, 'skip_mlp_layer_19': 1, 'skip_mlp_layer_20': 1, 'skip_mlp_layer_21': 1, 'skip_mlp_layer_22': 0, 'skip_mlp_layer_23': 0, 'skip_mlp_layer_24': 0, 'skip_mlp_layer_25': 1}. Best is trial 24 with value: 0.15.
The best params: {'skip_attn_layer_0': 0, 'skip_attn_layer_1': 0, 'skip_attn_layer_2': 0, 'skip_attn_layer_3': 0, 'skip_attn_layer_4': 1, 'skip_attn_layer_5': 0, 'skip_attn_layer_6': 0, 'skip_attn_layer_7': 0, 'skip_attn_layer_8': 0, 'skip_attn_layer_9': 1, 'skip_attn_layer_10': 0, 'skip_attn_layer_11': 0, 'skip_attn_layer_12': 0, 'skip_attn_layer_13': 0, 'skip_attn_layer_14': 0, 'skip_attn_layer_15': 1, 'skip_attn_layer_16': 0, 'skip_attn_layer_17': 1, 'skip_attn_layer_18': 1, 'skip_attn_layer_19': 1, 'skip_attn_layer_20': 1, 'skip_attn_layer_21': 0, 'skip_attn_layer_22': 0, 'skip_attn_layer_23': 0, 'skip_attn_layer_24': 0, 'skip_attn_layer_25': 1, 'skip_mlp_layer_0': 0, 'skip_mlp_layer_1': 0, 'skip_mlp_layer_2': 1, 'skip_mlp_layer_3': 0, 'skip_mlp_layer_4': 0, 'skip_mlp_layer_5': 0, 'skip_mlp_layer_6': 0, 'skip_mlp_layer_7': 1, 'skip_mlp_layer_8': 1, 'skip_mlp_layer_9': 1, 'skip_mlp_layer_10': 1, 'skip_mlp_layer_11': 0, 'skip_mlp_layer_12': 0, 'skip_mlp_layer_13': 0, 'skip_mlp_layer_14': 1, 'skip_mlp_layer_15': 0, 'skip_mlp_layer_16': 1, 'skip_mlp_layer_17': 0, 'skip_mlp_layer_18': 0, 'skip_mlp_layer_19': 1, 'skip_mlp_layer_20': 1, 'skip_mlp_layer_21': 1, 'skip_mlp_layer_22': 0, 'skip_mlp_layer_23': 0, 'skip_mlp_layer_24': 0, 'skip_mlp_layer_25': 1}
The best accept_rate: 0.15
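Once the study finishes, study.best_params is just a flat dictionary of 0/1 flags, so it has to be mapped back into the skip_layer_ids format before it can be applied with set_skip_layer_ids(). A minimal sketch of that conversion, re-using the parameter names defined in objective() above:

best_skip_layer_ids = {
    "attn": [i for i in range(26) if study.best_params[f"skip_attn_layer_{i}"] == 1],
    "mlp": [i for i in range(26) if study.best_params[f"skip_mlp_layer_{i}"] == 1],
}
model.set_skip_layer_ids(skip_layer_ids=best_skip_layer_ids)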

The final test results were disappointing, with an acceptance rate of only 0.15. However, based on my preliminary tests, larger models tend to tolerate layer skipping better.

