Using Finite State Machine (FSM) and Rollback Mechanism to Restrict LLM from Generating Banned Words

Last Updated on 2024-10-29 by Clay

When implementing various services through LLMs, do you worry about uncontrolled language generation? Recently, at a critical juncture in wrapping up a project, I used tools like Outlines to constrain LLM decoding, which effectively controlled the model's output to follow the desired patterns. However, a colleague posed a deep question: What if I want it not to generate specific words?

That’s indeed a great question! Our PM once asked the same thing, but at the time, with more pressing deadlines, I reflexively replied that it was challenging. Only now am I seriously considering it and realizing that theoretically, it should be achievable.

(2024/10/28 Update: I've uploaded the source code here: https://github.com/ccs96307/llm-decode-filter-special-words. Feedback is very welcome!)


Overview of Existing Techniques

I began by researching similar work:

Logit Biasing and Token Masking: Some models support modifying token probability distributions directly, like OpenAI’s logit_bias parameter, which excludes or amplifies specific tokens in decoding.

Regex-based Constraints: Previously explored in this article on Outlines constraints, this approach leverages a finite state machine to control the LLM’s current state and restrict its decoding range.

Direct Token Masking: The simplest approach is to mask specific tokens, ensuring they are never generated.

However, these solutions don’t entirely fit my needs, because my constraints should apply to words rather than individual tokens (for example, "ham-bur-ger" forms one word but consists of three tokens). For instance, I want to block "fuck you" as a curse phrase without blocking "you" on its own, so that normal uses of "you" in conversation remain unaffected.
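To make the limitation concrete, here is a minimal sketch of plain token masking on a toy vocabulary. The vocabulary, ids, and logits are made up for illustration, and `mask_tokens` and `greedy_pick` are hypothetical helper names, not part of any library.

```python
import math

# Toy vocabulary; the ids and tokens are invented for this illustration.
vocab = {0: "fuck", 1: "▁you", 2: "▁thank", 3: "▁see"}


def mask_tokens(logits, banned_ids):
    """Return a copy of logits with the banned token ids set to -inf."""
    return [-math.inf if i in banned_ids else x for i, x in enumerate(logits)]


def greedy_pick(logits):
    """Return the index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Suppose the context is "thank" and the model strongly prefers "▁you".
logits = [0.1, 3.0, 0.2, 1.5]

# Banning the token "▁you" to stop "fuck you" also removes it here,
# even though "thank you" is perfectly harmless.
masked = mask_tokens(logits, banned_ids={1})

print(vocab[greedy_pick(logits)])  # ▁you
print(vocab[greedy_pick(masked)])  # ▁see
```

Masking acts on single tokens regardless of context, which is why a word-level or phrase-level constraint needs some memory of what was generated before.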

Of course, given my limited search capability, I may have missed solutions that match my needs. If anyone knows of similar approaches, I would love to hear about them and give proper credit.


Implementation Approach

My approach involves creating an initial “restricted vocabulary FSM” for each model, as each model's vocabulary may vary.

First, we build the finite state machine:

  1. Define a list of restricted words
  2. Enumerate all possible tokens that could start decoding as restricted words
  3. Starting from the first token of a restricted word, keep extending with vocabulary tokens until every token combination that decodes to a word in the restricted list has been found. Each token_id becomes a transition in the FSM, and the transition for the final token of a word points to state -1, signifying a “successful match.”

Then, we enter the decoding process:

  1. Decoding starts in state 0. If the token decoded by the model doesn’t match any restricted first token in state 0, the FSM stays in state 0; if it matches, the FSM transitions to the corresponding next state.
  2. If the next decoded token has no transition from the current state, the FSM returns to state 0; otherwise, it keeps advancing through states. When a “successful match” is reached, the FSM transitions to -1 and decoding enters “roll-back mode.”
  3. In roll-back mode, decoding rolls back to the first token of the restricted sequence and prevents that token from being decoded at that position. The banned token is stored in a {position_idx: [token_id]} dictionary, so whenever decoding reaches position_idx, it first checks this dictionary and masks any banned tokens.

With this algorithm, the model can be kept from generating any restricted word whose token sequences are covered by the FSM.
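The decoding steps above can be sketched end to end on a toy example. Everything here is an assumption made for illustration: `toy_model` is a hand-written lookup table standing in for an LLM, the banned sequences are given as token strings, and the match check uses a plain suffix comparison instead of an FSM to keep the sketch short.

```python
from typing import Dict, List, Optional, Set, Tuple

# Hypothetical banned token sequences (two tokenizations of the same phrase)
BANNED: List[Tuple[str, ...]] = [("fuck", "▁you"), ("fu", "ck", "▁you")]


def toy_model(prefix: List[str]) -> List[str]:
    """Stand-in for an LLM: returns candidate next tokens, best first."""
    table = {
        (): ["fuck", "thank"],
        ("fuck",): ["▁you"],
        ("thank",): ["▁you"],
    }
    return table.get(tuple(prefix), ["<eos>"])


def matched_suffix(tokens: List[str]) -> Optional[Tuple[str, ...]]:
    """Return the banned sequence that `tokens` currently ends with, if any."""
    for seq in BANNED:
        if len(tokens) >= len(seq) and tuple(tokens[-len(seq):]) == seq:
            return seq
    return None


def decode(max_iterations: int = 20) -> List[str]:
    tokens: List[str] = []
    banned_at: Dict[int, Set[str]] = {}  # position -> tokens banned there

    for _ in range(max_iterations):
        pos = len(tokens)
        candidates = [t for t in toy_model(tokens) if t not in banned_at.get(pos, set())]
        if not candidates or candidates[0] == "<eos>":
            break
        tokens.append(candidates[0])

        hit = matched_suffix(tokens)
        if hit is not None:
            # Roll back to the first token of the banned sequence and ban it there
            start = len(tokens) - len(hit)
            banned_at.setdefault(start, set()).add(hit[0])
            # Bans recorded for later positions belonged to the discarded continuation
            for later in [p for p in banned_at if p > start]:
                del banned_at[later]
            tokens = tokens[:start]

    return tokens


print(decode())  # ['thank', '▁you']
```

The toy model first produces the banned pair, gets rolled back to position 0, and then takes its second choice. The token "▁you" is still generated afterwards because only the first token of the banned sequence was masked, and only at that position.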


Implementation Details

In the example below, we use the Gemma-2 model architecture. Ideally, I aim for this decoding algorithm to support various models. However, the vocabulary range for different models can vary, so some rules may not apply universally.

First, we prepare all configurations:

from typing import List, Dict, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Model and Tokenizer
pretrained_model_name_or_path = "./models/google--gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
model.to(device)



Step 1. Define a List of Restricted Words

Here, I define the restricted words as banned_words.

banned_words = ["talk", "listen", "fuck you"]



Step 2. Enumerate All Possible Generation Paths for Restricted Words

There are a few important considerations here. Not every token matches banned_words directly: the vocabulary of a model like Gemma-2 contains tokens such as ▁listen, where the leading ▁ signifies a “boundary” or “word start” (GPT-2 uses a different marker, Ġ, for the same purpose). We need to remove this marker before matching against banned_words.

Additionally, some tokens in our vocabulary consist only of whitespace or similar special characters. Because the combined string is stripped before matching, such tokens could be appended to a candidate over and over, and the number of enumerated decoding paths would explode. Thus, I use re.compile(r'^[\t\n\r\f\v\s]*$') to match and discard such tokens.

import re

new_token_candidates = []
new_token_id_candidates = []

# Define a pattern to filter out whitespace-only tokens (e.g., runs of tabs, newlines, or spaces)
invalid_token_pattern = re.compile(r'^[\t\n\r\f\v\s]*$')  # Matches strings that are empty or made up entirely of whitespace

# Initial matching of tokens that start with banned word prefixes
for token, token_id in tokenizer.vocab.items():
    token_stripped = token.strip()
    if len(token_stripped) == 0 or (len(token_stripped) == 1 and token_stripped == "▁"): 
        continue

    if token_stripped[0] == "▁":
        token_stripped = token_stripped[1:]

    for banned_word in banned_words:
        if banned_word.startswith(token_stripped):
            new_token_candidates.append([token])  # Store as a list for consistency
            new_token_id_candidates.append([token_id])

print("Filtered First Candidates:", new_token_candidates)
print("Filtered First IDs:", new_token_id_candidates)

# Initialize final routes for fully matched banned words
final_routes = []
final_id_routes = []
for new_token, new_token_id in zip(new_token_candidates, new_token_id_candidates):
    new_token_str = "".join(new_token).strip()

    if new_token_str in banned_words or (new_token_str[0] == "▁" and new_token_str[1:] in banned_words):
        final_routes.append(new_token)
        final_id_routes.append(new_token_id)

print("(beginning) Final Routes:", final_routes)
print("(beginning) Final ID Routes:", final_id_routes)

# Iteratively expand candidates using BFS-like approach
while new_token_candidates:
    curr_token_candidates = new_token_candidates
    curr_token_id_candidates = new_token_id_candidates
    new_token_candidates = []
    new_token_id_candidates = []

    # Iterate over vocabulary and expand each candidate
    for token, token_id in tokenizer.vocab.items():
        token_stripped = token.strip()
        if len(token_stripped) == 0 or (len(token_stripped) == 1 and token_stripped == "▁") or re.findall(r"\▁{2,100}", token_stripped):
            continue

        # Skip tokens that consist of excessive whitespace or control characters
        if invalid_token_pattern.match(token_stripped):
            continue

        for candidate_token, candidate_token_id in zip(curr_token_candidates, curr_token_id_candidates):
            curr_token = candidate_token + [token]
            curr_token_ids = candidate_token_id + [token_id]

            curr_token_str = "".join([token.replace("▁", " ") for token in curr_token]).strip()

            # Check if the current token combination matches or is a prefix of any banned word
            for banned_word in banned_words:
                if curr_token_str == banned_word:
                    # Full match found, add to final routes
                    final_routes.append(curr_token)
                    final_id_routes.append(curr_token_ids)
                elif banned_word.startswith(curr_token_str):
                    # Partial match, keep expanding this candidate
                    new_token_candidates.append(curr_token)
                    new_token_id_candidates.append(curr_token_ids)

    print(final_routes)

print("Final Routes:", final_routes)
print("Final ID Routes:", final_id_routes)


Now, our decoded generation paths can be represented as:

['listen'] [18998]
['▁listen'] [10724]
['▁talk'] [5063]
['talk'] [33085]
['▁tal', 'k'] [3412, 235273]
['tal', 'k'] [3559, 235273]
['li', 'sten'] [515, 5547]
['▁li', 'sten'] [702, 5547]
['ta', 'lk'] [516, 26159]
['▁ta', 'lk'] [3586, 26159]
['lis', 'ten'] [15063, 965]
['▁lis', 'ten'] [23966, 965]
['▁t', 'alk'] [474, 2071]
['t', 'alk'] [235251, 2071]
['▁liste', 'n'] [32165, 235254]
['liste', 'n'] [44003, 235254]
['▁list', 'en'] [1889, 479]
['list', 'en'] [1701, 479]
['l', 'isten'] [235257, 17071]
['▁l', 'isten'] [533, 17071]
['fuck', '▁you'] [34024, 692]
['▁fuck', '▁you'] [7935, 692]
['ta', 'l', 'k'] [516, 235257, 235273]
['▁ta', 'l', 'k'] [3586, 235257, 235273]
['▁t', 'al', 'k'] [474, 492, 235273]
['t', 'al', 'k'] [235251, 492, 235273]
['l', 'i', 'sten'] [235257, 235252, 5547]
['▁l', 'i', 'sten'] [533, 235252, 5547]
['▁t', 'a', 'lk'] [474, 235250, 26159]
['t', 'a', 'lk'] [235251, 235250, 26159]
['l', 'is', 'ten'] [235257, 502, 965]
['▁l', 'is', 'ten'] [533, 502, 965]
['li', 's', 'ten'] [515, 235256, 965]
['▁li', 's', 'ten'] [702, 235256, 965]
['fuck', '▁y', 'ou'] [34024, 597, 507]
['▁fuck', '▁y', 'ou'] [7935, 597, 507]
['l', 'iste', 'n'] [235257, 3671, 235254]
['▁l', 'iste', 'n'] [533, 3671, 235254]
['li', 'ste', 'n'] [515, 2855, 235254]
['▁li', 'ste', 'n'] [702, 2855, 235254]
['lis', 'te', 'n'] [15063, 488, 235254]
['▁lis', 'te', 'n'] [23966, 488, 235254]
['▁list', 'e', 'n'] [1889, 235249, 235254]
['list', 'e', 'n'] [1701, 235249, 235254]
['l', 'ist', 'en'] [235257, 694, 479]
['▁l', 'ist', 'en'] [533, 694, 479]
['li', 'st', 'en'] [515, 490, 479]
['▁li', 'st', 'en'] [702, 490, 479]
['lis', 't', 'en'] [15063, 235251, 479]
['▁lis', 't', 'en'] [23966, 235251, 479]
['▁fu', 'ck', '▁you'] [4936, 623, 692]
['fu', 'ck', '▁you'] [12819, 623, 692]
['▁fuc', 'k', '▁you'] [79433, 235273, 692]
['▁f', 'uck', '▁you'] [517, 1870, 692]
['f', 'uck', '▁you'] [235266, 1870, 692]
['fuck', '▁yo', 'u'] [34024, 10931, 235261]
['▁fuck', '▁yo', 'u'] [7935, 10931, 235261]
['▁t', 'a', 'l', 'k'] [474, 235250, 235257, 235273]
['t', 'a', 'l', 'k'] [235251, 235250, 235257, 235273]
['l', 'i', 's', 'ten'] [235257, 235252, 235256, 965]
['▁l', 'i', 's', 'ten'] [533, 235252, 235256, 965]
['▁fu', 'ck', '▁y', 'ou'] [4936, 623, 597, 507]
['fu', 'ck', '▁y', 'ou'] [12819, 623, 597, 507]
['▁fuc', 'k', '▁y', 'ou'] [79433, 235273, 597, 507]
['▁f', 'uck', '▁y', 'ou'] [517, 1870, 597, 507]
['f', 'uck', '▁y', 'ou'] [235266, 1870, 597, 507]
['l', 'i', 'ste', 'n'] [235257, 235252, 2855, 235254]
['▁l', 'i', 'ste', 'n'] [533, 235252, 2855, 235254]
['l', 'is', 'te', 'n'] [235257, 502, 488, 235254]
['▁l', 'is', 'te', 'n'] [533, 502, 488, 235254]
['li', 's', 'te', 'n'] [515, 235256, 488, 235254]
['▁li', 's', 'te', 'n'] [702, 235256, 488, 235254]
['l', 'ist', 'e', 'n'] [235257, 694, 235249, 235254]
['▁l', 'ist', 'e', 'n'] [533, 694, 235249, 235254]
['li', 'st', 'e', 'n'] [515, 490, 235249, 235254]
['▁li', 'st', 'e', 'n'] [702, 490, 235249, 235254]
['lis', 't', 'e', 'n'] [15063, 235251, 235249, 235254]
['▁lis', 't', 'e', 'n'] [23966, 235251, 235249, 235254]
['l', 'i', 'st', 'en'] [235257, 235252, 490, 479]
['▁l', 'i', 'st', 'en'] [533, 235252, 490, 479]
['l', 'is', 't', 'en'] [235257, 502, 235251, 479]
['▁l', 'is', 't', 'en'] [533, 502, 235251, 479]
['li', 's', 't', 'en'] [515, 235256, 235251, 479]
['▁li', 's', 't', 'en'] [702, 235256, 235251, 479]
['▁f', 'u', 'ck', '▁you'] [517, 235261, 623, 692]
['f', 'u', 'ck', '▁you'] [235266, 235261, 623, 692]
['▁fu', 'c', 'k', '▁you'] [4936, 235260, 235273, 692]
['fu', 'c', 'k', '▁you'] [12819, 235260, 235273, 692]
['▁f', 'uc', 'k', '▁you'] [517, 1669, 235273, 692]
['f', 'uc', 'k', '▁you'] [235266, 1669, 235273, 692]
['fuck', '▁y', 'o', 'u'] [34024, 597, 235253, 235261]
['▁fuck', '▁y', 'o', 'u'] [7935, 597, 235253, 235261]
['▁fu', 'ck', '▁yo', 'u'] [4936, 623, 10931, 235261]
['fu', 'ck', '▁yo', 'u'] [12819, 623, 10931, 235261]
['▁fuc', 'k', '▁yo', 'u'] [79433, 235273, 10931, 235261]
['▁f', 'uck', '▁yo', 'u'] [517, 1870, 10931, 235261]
['f', 'uck', '▁yo', 'u'] [235266, 1870, 10931, 235261]
['▁f', 'u', 'ck', '▁y', 'ou'] [517, 235261, 623, 597, 507]
['f', 'u', 'ck', '▁y', 'ou'] [235266, 235261, 623, 597, 507]
['▁fu', 'c', 'k', '▁y', 'ou'] [4936, 235260, 235273, 597, 507]
['fu', 'c', 'k', '▁y', 'ou'] [12819, 235260, 235273, 597, 507]
['▁f', 'uc', 'k', '▁y', 'ou'] [517, 1669, 235273, 597, 507]
['f', 'uc', 'k', '▁y', 'ou'] [235266, 1669, 235273, 597, 507]
['l', 'i', 's', 'te', 'n'] [235257, 235252, 235256, 488, 235254]
['▁l', 'i', 's', 'te', 'n'] [533, 235252, 235256, 488, 235254]
['l', 'i', 'st', 'e', 'n'] [235257, 235252, 490, 235249, 235254]
['▁l', 'i', 'st', 'e', 'n'] [533, 235252, 490, 235249, 235254]
['l', 'is', 't', 'e', 'n'] [235257, 502, 235251, 235249, 235254]
['▁l', 'is', 't', 'e', 'n'] [533, 502, 235251, 235249, 235254]
['li', 's', 't', 'e', 'n'] [515, 235256, 235251, 235249, 235254]
['▁li', 's', 't', 'e', 'n'] [702, 235256, 235251, 235249, 235254]
['l', 'i', 's', 't', 'en'] [235257, 235252, 235256, 235251, 479]
['▁l', 'i', 's', 't', 'en'] [533, 235252, 235256, 235251, 479]
['▁f', 'u', 'c', 'k', '▁you'] [517, 235261, 235260, 235273, 692]
['f', 'u', 'c', 'k', '▁you'] [235266, 235261, 235260, 235273, 692]
['▁fu', 'ck', '▁y', 'o', 'u'] [4936, 623, 597, 235253, 235261]
['fu', 'ck', '▁y', 'o', 'u'] [12819, 623, 597, 235253, 235261]
['▁fuc', 'k', '▁y', 'o', 'u'] [79433, 235273, 597, 235253, 235261]
['▁f', 'uck', '▁y', 'o', 'u'] [517, 1870, 597, 235253, 235261]
['f', 'uck', '▁y', 'o', 'u'] [235266, 1870, 597, 235253, 235261]
['▁f', 'u', 'ck', '▁yo', 'u'] [517, 235261, 623, 10931, 235261]
['f', 'u', 'ck', '▁yo', 'u'] [235266, 235261, 623, 10931, 235261]
['▁fu', 'c', 'k', '▁yo', 'u'] [4936, 235260, 235273, 10931, 235261]
['fu', 'c', 'k', '▁yo', 'u'] [12819, 235260, 235273, 10931, 235261]
['▁f', 'uc', 'k', '▁yo', 'u'] [517, 1669, 235273, 10931, 235261]
['f', 'uc', 'k', '▁yo', 'u'] [235266, 1669, 235273, 10931, 235261]
['▁f', 'u', 'c', 'k', '▁y', 'ou'] [517, 235261, 235260, 235273, 597, 507]
['f', 'u', 'c', 'k', '▁y', 'ou'] [235266, 235261, 235260, 235273, 597, 507]
['l', 'i', 's', 't', 'e', 'n'] [235257, 235252, 235256, 235251, 235249, 235254]
['▁l', 'i', 's', 't', 'e', 'n'] [533, 235252, 235256, 235251, 235249, 235254]
['▁f', 'u', 'ck', '▁y', 'o', 'u'] [517, 235261, 623, 597, 235253, 235261]
['f', 'u', 'ck', '▁y', 'o', 'u'] [235266, 235261, 623, 597, 235253, 235261]
['▁fu', 'c', 'k', '▁y', 'o', 'u'] [4936, 235260, 235273, 597, 235253, 235261]
['fu', 'c', 'k', '▁y', 'o', 'u'] [12819, 235260, 235273, 597, 235253, 235261]
['▁f', 'uc', 'k', '▁y', 'o', 'u'] [517, 1669, 235273, 597, 235253, 235261]
['f', 'uc', 'k', '▁y', 'o', 'u'] [235266, 1669, 235273, 597, 235253, 235261]
['▁f', 'u', 'c', 'k', '▁yo', 'u'] [517, 235261, 235260, 235273, 10931, 235261]
['f', 'u', 'c', 'k', '▁yo', 'u'] [235266, 235261, 235260, 235273, 10931, 235261]
['▁f', 'u', 'c', 'k', '▁y', 'o', 'u'] [517, 235261, 235260, 235273, 597, 235253, 235261]
['f', 'u', 'c', 'k', '▁y', 'o', 'u'] [235266, 235261, 235260, 235273, 597, 235253, 235261]
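To see the enumeration at a smaller scale, here is a sketch on a toy vocabulary. It is not the code used to produce the listing above: it uses a depth-first recursion instead of the iterative expansion, and `normalize`, `enumerate_routes`, and `toy_vocab` are hypothetical names introduced for this example.

```python
from typing import Dict, List


def normalize(token: str) -> str:
    """Map the SentencePiece word-start marker to a plain space."""
    return token.replace("▁", " ")


def enumerate_routes(word: str, vocab: Dict[str, int]) -> List[List[str]]:
    """Return every token sequence in `vocab` whose concatenation spells `word`.

    A leading word-start marker on the first token is ignored, similar to
    the strip() applied before matching in the enumeration code.
    """
    routes: List[List[str]] = []

    def extend(route: List[str], spelled: str) -> None:
        if spelled == word:
            routes.append(route)
            return
        for token in vocab:
            piece = normalize(token)
            if not piece.strip():
                continue  # skip whitespace-only tokens
            candidate = (spelled + piece).lstrip()
            if word.startswith(candidate):
                extend(route + [token], candidate)

    extend([], "")
    return routes


# Toy vocabulary; the ids are placeholders and unused by the enumeration.
toy_vocab = {"talk": 0, "▁talk": 1, "ta": 2, "▁ta": 3, "lk": 4, "t": 5, "alk": 6, "you": 7}

for route in enumerate_routes("talk", toy_vocab):
    print(route)
# ['talk']
# ['▁talk']
# ['ta', 'lk']
# ['▁ta', 'lk']
# ['t', 'alk']
```

Even with eight tokens there are five ways to spell one short word, which gives a feel for why the real vocabulary yields well over a hundred routes for three banned entries.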


Using these enumerated paths, we can implement a finite state machine to detect any restricted token sequences in the current state. Details on FSM implementation can be referenced in my previous article on FSM for controlling LLM decoding, which provides a simpler implementation.

class FSMProcessor:
    def __init__(self, special_token_ids_list: List[List[int]], end_state: int = -1) -> None:
        self.end_state = end_state
        self.next_state = 1
        self.curr_state = 0
        self.fsm = {}
        self.special_words = []

        # Track partial matches
        self.partial_match_state = None
        self.partial_tokens = []

        self.update_group(special_token_ids_list)

    def update(self, special_token_ids: List[int]) -> None:
        """Insert one banned token-id route into the FSM."""
        curr_state = 0
        last_idx = len(special_token_ids) - 1

        for idx, special_token_id in enumerate(special_token_ids):
            # A shorter banned route already ends here, so the longer one is redundant
            if curr_state == self.end_state:
                break

            transitions = self.fsm.setdefault(curr_state, [])
            transition = next((t for t in transitions if t[0] == special_token_id), None)

            if transition is None:
                if idx == last_idx:
                    transitions.append([special_token_id, self.end_state])
                else:
                    transitions.append([special_token_id, self.next_state])
                    curr_state = self.next_state
                    self.next_state += 1
            elif idx == last_idx:
                transition[1] = self.end_state
            else:
                curr_state = transition[1]

    def update_group(self, special_token_ids_list: List[List[int]]) -> None:
        for special_token_ids in special_token_ids_list:
            self.update(special_token_ids=special_token_ids)

    def get_fsm_data(self) -> Dict[int, List[List[int]]]:
        return self.fsm

    def detect(self, token: int) -> bool:
        """
        Detect if the current token leads to a sensitive sequence.
        Updates the current state and returns True if it reaches the end state.
        The tokens of the match in progress are kept in self.partial_tokens.
        """
        # Try to extend the match in progress first; if that fails, retry from
        # state 0 so a token that starts a new sensitive sequence is not missed
        for state in (self.curr_state, 0):
            for token_id, next_state in self.fsm.get(state, []):
                if token_id == token:
                    if state == 0:
                        self.partial_tokens = []
                    self.partial_tokens.append(token)
                    self.curr_state = next_state

                    # If the current state reaches the end state
                    return self.curr_state == self.end_state

        # If the token does not match, reset the current state
        self.curr_state = 0
        self.partial_tokens = []
        return False


fsm_processor = FSMProcessor(special_token_ids_list=final_id_routes)
fsm_processor.get_fsm_data()


Output:

{0: [[18998, -1],
[10724, -1],
[5063, -1],
[33085, -1],
[3412, 1],
[3559, 2],
[515, 3],
[702, 4],
[516, 5],
[3586, 6],
[15063, 7],
[23966, 8],
[474, 9],
[235251, 10],
[32165, 11],
[44003, 12],
[1889, 13],
[1701, 14],
[235257, 15],
[533, 16],
[34024, 17],
[7935, 18],
[4936, 47],
[12819, 49],
[79433, 51],
[517, 53],
[235266, 55]],
1: [[235273, -1]],
2: [[235273, -1]],
3: [[5547, -1], [235256, 29], [2855, 35], [490, 43]],
4: [[5547, -1], [235256, 30], [2855, 36], [490, 44]],
5: [[26159, -1], [235257, 19]],
6: [[26159, -1], [235257, 20]],
7: [[965, -1], [488, 37], [235251, 45]],
8: [[965, -1], [488, 38], [235251, 46]],
9: [[2071, -1], [492, 21], [235250, 25]],
10: [[2071, -1], [492, 22], [235250, 26]],
11: [[235254, -1]],
12: [[235254, -1]],
13: [[479, -1], [235249, 39]],
14: [[479, -1], [235249, 40]],
15: [[17071, -1], [235252, 23], [502, 27], [3671, 33], [694, 41]],
16: [[17071, -1], [235252, 24], [502, 28], [3671, 34], [694, 42]],
17: [[692, -1], [597, 31], [10931, 57]],
18: [[692, -1], [597, 32], [10931, 58]],
19: [[235273, -1]],
20: [[235273, -1]],
21: [[235273, -1]],
22: [[235273, -1]],
23: [[5547, -1], [235256, 61], [2855, 68], [490, 80]],
24: [[5547, -1], [235256, 62], [2855, 69], [490, 81]],
25: [[26159, -1], [235257, 59]],
26: [[26159, -1], [235257, 60]],
27: [[965, -1], [488, 70], [235251, 82]],
28: [[965, -1], [488, 71], [235251, 83]],
29: [[965, -1], [488, 72], [235251, 84]],
30: [[965, -1], [488, 73], [235251, 85]],
31: [[507, -1], [235253, 98]],
32: [[507, -1], [235253, 99]],
33: [[235254, -1]],
34: [[235254, -1]],
35: [[235254, -1]],
36: [[235254, -1]],
37: [[235254, -1]],
38: [[235254, -1]],
39: [[235254, -1]],
40: [[235254, -1]],
41: [[479, -1], [235249, 74]],
42: [[479, -1], [235249, 75]],
43: [[479, -1], [235249, 76]],
44: [[479, -1], [235249, 77]],
45: [[479, -1], [235249, 78]],
46: [[479, -1], [235249, 79]],
47: [[623, 48], [235260, 90]],
48: [[692, -1], [597, 63], [10931, 100]],
49: [[623, 50], [235260, 92]],
50: [[692, -1], [597, 64], [10931, 101]],
51: [[235273, 52]],
52: [[692, -1], [597, 65], [10931, 102]],
53: [[1870, 54], [235261, 86], [1669, 94]],
54: [[692, -1], [597, 66], [10931, 103]],
55: [[1870, 56], [235261, 88], [1669, 96]],
56: [[692, -1], [597, 67], [10931, 104]],
57: [[235261, -1]],
58: [[235261, -1]],
59: [[235273, -1]],
60: [[235273, -1]],
61: [[965, -1], [488, 111], [235251, 119]],
62: [[965, -1], [488, 112], [235251, 120]],
63: [[507, -1], [235253, 125]],
64: [[507, -1], [235253, 126]],
65: [[507, -1], [235253, 127]],
66: [[507, -1], [235253, 128]],
67: [[507, -1], [235253, 129]],
68: [[235254, -1]],
69: [[235254, -1]],
70: [[235254, -1]],
71: [[235254, -1]],
72: [[235254, -1]],
73: [[235254, -1]],
74: [[235254, -1]],
75: [[235254, -1]],
76: [[235254, -1]],
77: [[235254, -1]],
78: [[235254, -1]],
79: [[235254, -1]],
80: [[479, -1], [235249, 113]],
81: [[479, -1], [235249, 114]],
82: [[479, -1], [235249, 115]],
83: [[479, -1], [235249, 116]],
84: [[479, -1], [235249, 117]],
85: [[479, -1], [235249, 118]],
86: [[623, 87], [235260, 121]],
87: [[692, -1], [597, 105], [10931, 130]],
88: [[623, 89], [235260, 123]],
89: [[692, -1], [597, 106], [10931, 131]],
90: [[235273, 91]],
91: [[692, -1], [597, 107], [10931, 132]],
92: [[235273, 93]],
93: [[692, -1], [597, 108], [10931, 133]],
94: [[235273, 95]],
95: [[692, -1], [597, 109], [10931, 134]],
96: [[235273, 97]],
97: [[692, -1], [597, 110], [10931, 135]],
98: [[235261, -1]],
99: [[235261, -1]],
100: [[235261, -1]],
101: [[235261, -1]],
102: [[235261, -1]],
103: [[235261, -1]],
104: [[235261, -1]],
105: [[507, -1], [235253, 140]],
106: [[507, -1], [235253, 141]],
107: [[507, -1], [235253, 142]],
108: [[507, -1], [235253, 143]],
109: [[507, -1], [235253, 144]],
110: [[507, -1], [235253, 145]],
111: [[235254, -1]],
112: [[235254, -1]],
113: [[235254, -1]],
114: [[235254, -1]],
115: [[235254, -1]],
116: [[235254, -1]],
117: [[235254, -1]],
118: [[235254, -1]],
119: [[479, -1], [235249, 138]],
120: [[479, -1], [235249, 139]],
121: [[235273, 122]],
122: [[692, -1], [597, 136], [10931, 146]],
123: [[235273, 124]],
124: [[692, -1], [597, 137], [10931, 147]],
125: [[235261, -1]],
126: [[235261, -1]],
127: [[235261, -1]],
128: [[235261, -1]],
129: [[235261, -1]],
130: [[235261, -1]],
131: [[235261, -1]],
132: [[235261, -1]],
133: [[235261, -1]],
134: [[235261, -1]],
135: [[235261, -1]],
136: [[507, -1], [235253, 148]],
137: [[507, -1], [235253, 149]],
138: [[235254, -1]],
139: [[235254, -1]],
140: [[235261, -1]],
141: [[235261, -1]],
142: [[235261, -1]],
143: [[235261, -1]],
144: [[235261, -1]],
145: [[235261, -1]],
146: [[235261, -1]],
147: [[235261, -1]],
148: [[235261, -1]],
149: [[235261, -1]]}
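To illustrate how such a transition table is walked during decoding, here is a small sketch that uses an excerpt of the table printed above. The helper names `step` and `first_hit` are hypothetical, and retrying a non-matching token from state 0 is an assumption of this sketch; it prevents a word that begins right after an abandoned partial match from slipping through.

```python
from typing import Dict, List

# Excerpt of the printed table: state -> [[token_id, next_state], ...]
fsm_excerpt: Dict[int, List[List[int]]] = {
    0: [[10724, -1], [702, 4]],  # '▁listen' completes a word; '▁li' starts one
    4: [[5547, -1]],             # 'sten' after '▁li' completes "listen"
}
END_STATE = -1


def step(state: int, token_id: int) -> int:
    """Return the next state, falling back to state 0 when nothing matches."""
    for candidate_state in (state, 0):  # retry from state 0 on a mismatch
        for tid, next_state in fsm_excerpt.get(candidate_state, []):
            if tid == token_id:
                return next_state
    return 0


def first_hit(token_ids: List[int]) -> int:
    """Return the index at which a banned sequence completes, or -1 if none does."""
    state = 0
    for idx, token_id in enumerate(token_ids):
        state = step(state, token_id)
        if state == END_STATE:
            return idx
    return -1


# 577 ('to') and 1707 ('help') are ordinary token ids that have no transition here
print(first_hit([577, 702, 5547]))  # 2
print(first_hit([577, 1707]))       # -1
print(first_hit([702, 10724]))      # 1
```

In the last call, '▁li' opens a partial match, '▁listen' does not continue it, and the retry from state 0 still catches '▁listen' as a complete banned token.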



Step 3: Standard Decoding vs FSM + Rollback Decoding

Now, let's first see a simple greedy decoding scenario without constraints:

def custom_generate(input_ids: torch.Tensor, max_length: int = 50) -> str:
    input_ids = input_ids.to(device)

    for _ in range(max_length):
        # Run the full sequence through the model and take the logits of the last position
        with torch.no_grad():
            outputs = model(input_ids, return_dict=True)
        logits = outputs.logits[:, -1, :]

        # Greedy decoding: pick the most likely next token
        new_generated_token_id = torch.argmax(logits, dim=-1)

        if new_generated_token_id.item() == tokenizer.eos_token_id:
            break

        input_ids = torch.cat((input_ids, new_generated_token_id.unsqueeze(0)), dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)


input_text = "Can we talk?"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
print(custom_generate(input_ids=inputs.input_ids))


Output:

<bos>Can we talk?

I'm here to listen and help in any way I can.

What's on your mind?
<end_of_turn>

As seen, the model generates a fairly standard response. Now, we’ll add our banned_words filter to restrict certain sequences from being generated. Our list includes: ["talk", "listen", "fuck you"].

def custom_generate_with_fsm_filter(
    input_ids: torch.Tensor,
    fsm_processor: FSMProcessor,
    max_length: int = 20,
) -> str:
    # Historical mask: maps a decoding step to the set of token ids banned at that step
    masked_tokens_history = {}
    input_ids = input_ids.to(device)
    fsm_processor.curr_state = 0
    steps = 0

    while steps < max_length:
        steps += 1

        # Re-encode the full sequence at every step. Without a KV cache there is
        # no cached state to truncate when a rollback shortens the sequence.
        with torch.no_grad():
            outputs = model(input_ids, return_dict=True)
        logits = outputs.logits[:, -1, :]

        # Apply the tokens already banned at the current step
        for masked_token_id in masked_tokens_history.setdefault(steps, set()):
            logits[:, masked_token_id] = -float("inf")

        # Decode the generated token (greedy)
        generated_token_id = torch.argmax(logits, dim=-1).item()
        combined_ids = torch.cat((input_ids, torch.tensor([[generated_token_id]], device=input_ids.device)), dim=-1)

        # Check FSM for sensitive sequences
        if fsm_processor.detect(generated_token_id):
            # Roll back to the first token of the matched sequence. If the FSM does
            # not report the matched tokens, only the last token is rolled back.
            rollback_length = len(fsm_processor.partial_tokens) or 1
            first_token_id = combined_ids[0, -rollback_length].item()
            steps = steps - rollback_length + 1
            input_ids = combined_ids[:, :-rollback_length]
            print(f"Rollback detected. Rolling back from step {steps + rollback_length} to step {steps}")
            print(f"Masking token id: {first_token_id}, Masking token: {tokenizer.decode(first_token_id)}")

            # Ban the first token at its step. Bans recorded for later steps
            # belonged to the discarded continuation and no longer apply.
            masked_tokens_history.setdefault(steps, set()).add(first_token_id)
            for later_step in [s for s in masked_tokens_history if s > steps]:
                del masked_tokens_history[later_step]

            # Reset FSM state, then decode this step again with the new mask so
            # that the replacement token is also checked by the FSM
            fsm_processor.curr_state = 0
            fsm_processor.partial_tokens = []
            steps -= 1
            continue

        # Update input_ids with the generated token
        input_ids = combined_ids

        print(f"Step {steps}: ID: {generated_token_id} Generated token: {tokenizer.decode(generated_token_id)}")

        if generated_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)


response = custom_generate_with_fsm_filter(
    input_ids=inputs.input_ids,
    fsm_processor=fsm_processor,
    max_length=50,
)
print(response)


Output:

Step 1: ID: 109 Generated token: 


Step 2: ID: 235285 Generated token: I
Step 3: ID: 235303 Generated token: '
Step 4: ID: 235262 Generated token: m
Step 5: ID: 1517 Generated token: here
Step 6: ID: 577 Generated token: to
Rollback detected. Rolling back from step 8 to step 7
Masking token id: 10724, Masking token: listen
Step 7: ID: 1707 Generated token: help
Step 8: ID: 692 Generated token: you
Step 9: ID: 675 Generated token: with
Step 10: ID: 9550 Generated token: whatever
Step 11: ID: 692 Generated token: you
Step 12: ID: 1476 Generated token: need
Step 13: ID: 235265 Generated token: .
Step 14: ID: 235248 Generated token:
Step 15: ID: 109 Generated token:


Step 16: ID: 5958 Generated token: Please
Step 17: ID: 3337 Generated token: tell
Step 18: ID: 682 Generated token: me
Step 19: ID: 1212 Generated token: what
Step 20: ID: 235303 Generated token: '
Step 21: ID: 235256 Generated token: s
Step 22: ID: 611 Generated token: on
Step 23: ID: 861 Generated token: your
Step 24: ID: 3403 Generated token: mind
Step 25: ID: 235265 Generated token: .
Step 26: ID: 44416 Generated token: 😊
Step 27: ID: 108 Generated token:

Step 28: ID: 107 Generated token: <end_of_turn>
Step 29: ID: 1 Generated token: <eos>

<bos>Can we talk?

I'm here to help you with whatever you need.

Please tell me what's on your mind. 😊
<end_of_turn><eos>

As shown, when the model initially generated the ▁listen token, the FSM successfully detected it and used the rollback mechanism to reset the decoding sequence to before that token was generated, effectively masking it (Token ID=10724). This demonstrates a simple but effective FSM-based decoding constraint to restrict certain words from appearing. Of course, there may still be edge cases and exceptions, and I’ll continue optimizing this solution as time allows.

Thank you for reading, and feel free to provide feedback!

