Implementation of Using Finite-State Machine to Constrain Large Language Model Decoding

Last Updated on 2024-09-05 by Clay

This is a simple Python implementation used to test how a Finite-State Machine (FSM) can constrain a Large Language Model (LLM) to decode responses in a specific format. It also serves as an introduction to the concept behind the Outlines tool. Of course, my implementation is far simpler than the actual Outlines tool.

Let's start by quickly reviewing the LLM decoding process:

LLMs are generally based on a decoder-only architecture, designed specifically for decoding. The unit of decoding is a token, not a word as we usually understand it (although tokenizers for languages like Chinese, Japanese, and Korean have advanced to the point where a single character or word often maps to one token).

During decoding, the LLM doesn't complete the process in one go; instead, it decodes tokens incrementally based on the context, appending each decoded token to the input before decoding the next token. This process continues until a stop token is generated or the system reaches the predefined token limit.
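
To make this concrete, here is a minimal sketch of that loop written with the Hugging Face transformers library. The model name, prompt, and 20-token limit are arbitrary example choices:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM works here; "gpt2" is just a small example model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The weather today is", return_tensors="pt").input_ids

for _ in range(20):  # predefined token limit
    logits = model(input_ids).logits[:, -1, :]               # scores for the next token
    next_token = logits.argmax(dim=-1, keepdim=True)         # greedy choice
    input_ids = torch.cat([input_ids, next_token], dim=-1)   # append, then decode the next token
    if next_token.item() == tokenizer.eos_token_id:          # stop token reached
        break

print(tokenizer.decode(input_ids[0]))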

This is where the Finite-State Machine comes in to constrain the model's behavior. Throughout the generation process, we track which state the model is in, and only accept tokens that are valid transitions from that state. Conceptually, this is similar to a Trie.
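
Conceptually, this constraint is applied by masking the logits before sampling: from the current FSM state, only the token IDs that correspond to legal transitions keep their scores, and every other token is set to negative infinity. Below is a minimal sketch of that idea; the state-to-token table here is a made-up example, not how Outlines actually stores it:

import torch

# Hypothetical table built from an FSM: state -> token IDs that are legal next
allowed_tokens_per_state = {0: [17, 42], 3: [5, 99]}

def mask_logits(logits: torch.Tensor, state: int) -> torch.Tensor:
    # Forbid every token, then restore the scores of the legal ones
    masked = torch.full_like(logits, float("-inf"))
    allowed = allowed_tokens_per_state[state]
    masked[allowed] = logits[allowed]
    return masked

logits = torch.randn(100)  # pretend the vocabulary has 100 tokens
next_token_id = mask_logits(logits, state=3).argmax().item()
print(next_token_id)  # always 5 or 99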


Code Implementation

Now I will define a finite-state machine that only allows the LLM to decode the candidates "hot," "cold," and "hotel," using an asterisk (*) as the end token. This is because without an end token, it would be hard to know whether to stop after "hot" or continue decoding "hotel."

First, we define the finite-state machine diagram. Each circle represents a state, and the edges represent transitions. I deviated from convention by using state -1 as the end state, which might get me in trouble with FSM purists!

For example, we can see that in state 3 (state=3), only "*" or "e" are allowed to be decoded. If "*" is chosen, the final decoded result is "hot"; otherwise, it will be "hotel."

Below is the Python implementation. Here, I simulate decoding not with tokens, but by decoding one character (a-z, A-Z) at a time:

from typing import List, Tuple, Dict
import random


class FSMDecoder:
    def __init__(self, valid_words: List[str]) -> None:
        self.valid_words = valid_words
        self.states, self.final_states = self._create_fsm(valid_words=valid_words)
        self.current_state = 0

    def _create_fsm(self, valid_words: List[str]) -> Tuple[Dict[Tuple[int, str], int], set]:
        # Init
        states = {}
        state_id = 0
        root_state = 0
        end_state_id = -1

        for word in valid_words:
            current_state = root_state
            for char in word:
                if char == "*":
                    # The end token "*" transitions to the special end state (-1)
                    states[(current_state, char)] = end_state_id
                    break
                elif (current_state, char) not in states:
                    state_id += 1
                    states[(current_state, char)] = state_id
                current_state = states[(current_state, char)]

        # Final states: every state from which the end token "*" can be decoded,
        # i.e. the prefix spelled so far is already a complete candidate word
        final_states = {state for (state, char), next_state in states.items() if next_state == end_state_id}
        
        return states, final_states

    def decode(self, char: str) -> bool:
        if (self.current_state, char) in self.states:
            self.current_state = self.states[(self.current_state, char)]
            return True
        else:
            return False

    def is_valid(self) -> bool:
        return self.current_state in self.final_states


if __name__ == "__main__":
    # Example usage
    candidate_words = ["hot*", "cold*", "hotel*"]

    fsm_decoder = FSMDecoder(candidate_words)

    for key, value in fsm_decoder.states.items():
        print(f"{key} -> {value}")

    print()

    # Assume `*` is the end special token
    decode_list = ["*"]

    decode_list.extend(list(map(chr, range(65, 91))))
    decode_list.extend(list(map(chr, range(97, 123))))

    # Try random decoding!
    decode_word = ""

    while True:
        candidate_decode_char = random.choice(decode_list)

        can_decode = fsm_decoder.decode(candidate_decode_char)
        if can_decode:
            print("FSMDecoder State:", fsm_decoder.states)
            print("FSMDecoder Current State:", fsm_decoder.current_state)
            print()

            decode_word += candidate_decode_char

            if candidate_decode_char == "*":
                break

    print("Final Decode Word:", decode_word)

Output:

(0, 'h') -> 1
(1, 'o') -> 2
(2, 't') -> 3
(3, '*') -> -1
(0, 'c') -> 4
(4, 'o') -> 5
(5, 'l') -> 6
(6, 'd') -> 7
(7, '*') -> -1
(3, 'e') -> 8
(8, 'l') -> 9
(9, '*') -> -1

FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 1

FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 2

FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 3

FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: -1

Final Decode Word: hot*

We can see that the decoding states transition through 1 => 2 => 3 => -1, so the decoded word is "hot" (followed by the end token "*").
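
The loop above simply guesses characters at random and rejects anything the FSM will not accept. A more direct form of constrained decoding is to ask the FSM which characters are legal from the current state and choose only among them, which is the same logits-masking idea applied to this toy example. Here is a small variation that reuses the FSMDecoder class defined above:

def allowed_chars(fsm_decoder: FSMDecoder) -> List[str]:
    # All characters that have a transition out of the current state
    return [char for (state, char) in fsm_decoder.states if state == fsm_decoder.current_state]

fsm_decoder = FSMDecoder(["hot*", "cold*", "hotel*"])
decode_word = ""

# -1 is the end state, so stop once we reach it
while fsm_decoder.current_state != -1:
    char = random.choice(allowed_chars(fsm_decoder))  # every choice is guaranteed valid
    fsm_decoder.decode(char)
    decode_word += char

print("Final Decode Word:", decode_word)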

Of course, one might ask: Why bother with a finite-state machine for decoding? Couldn't we just use a Trie? Theoretically, that's correct, but in practice, Outlines is designed to allow the user to input regular expression patterns to constrain model generation! In such a case, a Trie wouldn't suffice.
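
The reason is that a Trie can only enumerate a finite list of strings, while a regular expression may describe infinitely many. For example, a pattern like ho+t matches "hot", "hoot", "hooot", and so on; an FSM represents this with a self-loop on a single state, something a Trie has no way to express. Below is a hand-built sketch of that FSM (the pattern is only an illustration, unrelated to the candidate words above):

# Transition table for the regex ho+t followed by the "*" end token
transitions = {
    (0, "h"): 1,
    (1, "o"): 2,
    (2, "o"): 2,   # the self-loop: any number of extra "o"s stays in state 2
    (2, "t"): 3,
    (3, "*"): -1,
}

def accepts(text: str) -> bool:
    state = 0
    for char in text:
        if (state, char) not in transitions:
            return False
        state = transitions[(state, char)]
    return state == -1

print(accepts("hot*"))    # True
print(accepts("hooot*"))  # True
print(accepts("ht*"))     # False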

