Skip to content

使用有限狀態機約束大型語言模型解碼之實作

Last Updated on 2024-09-04 by Clay

本篇為一簡單 Python 實作,用於測試有限狀態機Finite-State Machine)約束大型語言模型Large Language Model)解碼出特定格式的回答,也權當作是理解 Outlines 工具的理念。當然,我實作的部份跟 Outlines 工具相比,實在顯得太過簡易。

首先我們快速複習一下 LLM 的解碼:

LLM 多為 decoder-only 的架構,專門為解碼而生。而解碼的單位為 Token 而非我們理解的字(當然現在比較多對於中日韓等文字的實現,已經變成一個 character 或是一個 word 就是一個 Token 了)。

而 LLM 在解碼的過程中,並不是一次推理會全部完成解碼,而是基於上下文接龍的形式,會把解碼出的 Token 拼接回輸入,繼續往下解碼下一個 Token,直到解碼出結束符號或是碰到系統設定的 Token 上限。

這也是為什麼我們會需要有限狀態機來完成對模型的約束。在模型的生成過程中,我們會需要判斷模型在有限狀態機的哪個『狀態』(state),並只接受基於這個狀態可以往下解碼的 Token。概念上,我認為與字典樹有異曲同工之妙。


程式實作

現在我定義一個有限狀態機的系統,只允許 LLM 解碼 hot、code、hotel 三個候選選項,並且我會以 * 作為結束符號。這是因為如果沒有設定結束符號的話,我會難以判斷解碼出 hot 時是否會需要往下解碼出 hotel。

首先我們定義出有限狀態機的狀態圖,圖中的每個圓圈都是狀態編號,邊線即是操作。在這裡,我比較違反正規地選擇了狀態 -1 為結束符號,我大概會被正統派的學者圍攻致死。

舉個例子,我們可以看到在狀態 3(state=3)時,只允許解碼 * 跟 e。如果是 *,最後解碼的結果就是 hot、反之則一定是 hotel。

以下是 Python 的實作,在這裡我並不是真正模擬解碼 Token、而是一次只解碼一個英文字母,a - z 和 A - Z:

from typing import List, Tuple, Dict
import random


class FSMDecoder:
    def __init__(self, valid_words: List[str]) -> None:
        self.valid_words = valid_words
        self.states, self.final_states = self._create_fsm(valid_words=valid_words)
        self.current_state = 0

    def _create_fsm(self, valid_words: List[str]) -> Tuple[Dict[Tuple[int, str], int], set]:
        # Init
        states = {}
        state_id = 0
        root_state = 0
        end_state_id = -1

        for word in valid_words:
            current_state = root_state
            for char in word:
                if char == "*":
                    states[(current_state, char)] = end_state_id
                    break
                elif (current_state, char) not in states:
                    state_id += 1
                    states[(current_state, char)] = state_id
                current_state = states[(current_state, char)]
        
        final_states = {current_state for word in valid_words for char in word if '*' in word}
        
        return states, final_states

    def decode(self, char: str) -> bool:
        if (self.current_state, char) in self.states:
            self.current_state = self.states[(self.current_state, char)]
            return True
        else:
            return False

    def is_valid(self) -> bool:
        return self.current_state in self.final_states


if __name__ == "__main__":
    # Example usage
    candidate_words = ["hot*", "cold*", "hotel*"]

    fsm_decoder = FSMDecoder(candidate_words)

    for key, value in fsm_decoder.states.items():
        print(f"{key} -> {value}")

    print()

    # Assume `*` is the end special token
    decode_list = ["*"]

    decode_list.extend(list(map(chr, range(65, 91))))
    decode_list.extend(list(map(chr, range(97, 123))))

    # Try to random decoding!
    decode_word = ""

    while True:
        candidate_decode_char = random.choice(decode_list)

        can_decode = fsm_decoder.decode(candidate_decode_char)
        if can_decode:
            print("FSMDecoder State:", fsm_decoder.states)
            print("FSMDecoder Current State:", fsm_decoder.current_state)
            print()

            decode_word += candidate_decode_char

            if candidate_decode_char == "*":
                break

    print("Final Decode Word:", decode_word)


Output:

(0, 'h') -> 1
(1, 'o') -> 2
(2, 't') -> 3
(3, '*') -> -1
(0, 'c') -> 4
(4, 'o') -> 5
(5, 'l') -> 6
(6, 'd') -> 7
(7, '*') -> -1
(3, 'e') -> 8
(8, 'l') -> 9
(9, '*') -> -1

FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 1

FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 2

FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 3

FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: -1

Final Decode Word: hot*


我們可以看到解碼的狀態順序為 1 => 2 => 3 => -1,所以結果為 hot。

當然我們可能會想:這樣的解碼為什麼非得建立有限狀態機?建立個字典樹不就好了?這句話理論上是對的,不過實際上 Outlines 所設想的情境,甚至允許使用者輸入正規表示法的 pattern 來限制模型生成!這樣一來,字典樹可就不是和了。


References


Read More

Leave a Reply