Last Updated on 2024-09-04 by Clay
本篇為一簡單 Python 實作,用於測試有限狀態機(Finite-State Machine)約束大型語言模型(Large Language Model)解碼出特定格式的回答,也權當作是理解 Outlines 工具的理念。當然,我實作的部份跟 Outlines 工具相比,實在顯得太過簡易。
首先我們快速複習一下 LLM 的解碼:
LLM 多為 decoder-only 的架構,專門為解碼而生。而解碼的單位為 Token 而非我們理解的字(當然現在比較多對於中日韓等文字的實現,已經變成一個 character 或是一個 word 就是一個 Token 了)。
而 LLM 在解碼的過程中,並不是一次推理會全部完成解碼,而是基於上下文接龍的形式,會把解碼出的 Token 拼接回輸入,繼續往下解碼下一個 Token,直到解碼出結束符號或是碰到系統設定的 Token 上限。
這也是為什麼我們會需要有限狀態機來完成對模型的約束。在模型的生成過程中,我們會需要判斷模型在有限狀態機的哪個『狀態』(state),並只接受基於這個狀態可以往下解碼的 Token。概念上,我認為與字典樹有異曲同工之妙。
程式實作
現在我定義一個有限狀態機的系統,只允許 LLM 解碼 hot、code、hotel 三個候選選項,並且我會以 * 作為結束符號。這是因為如果沒有設定結束符號的話,我會難以判斷解碼出 hot 時是否會需要往下解碼出 hotel。
首先我們定義出有限狀態機的狀態圖,圖中的每個圓圈都是狀態編號,邊線即是操作。在這裡,我比較違反正規地選擇了狀態 -1 為結束符號,我大概會被正統派的學者圍攻致死。
舉個例子,我們可以看到在狀態 3(state=3)時,只允許解碼 * 跟 e。如果是 *,最後解碼的結果就是 hot、反之則一定是 hotel。
以下是 Python 的實作,在這裡我並不是真正模擬解碼 Token、而是一次只解碼一個英文字母,a - z 和 A - Z:
from typing import List, Tuple, Dict
import random
class FSMDecoder:
def __init__(self, valid_words: List[str]) -> None:
self.valid_words = valid_words
self.states, self.final_states = self._create_fsm(valid_words=valid_words)
self.current_state = 0
def _create_fsm(self, valid_words: List[str]) -> Tuple[Dict[Tuple[int, str], int], set]:
# Init
states = {}
state_id = 0
root_state = 0
end_state_id = -1
for word in valid_words:
current_state = root_state
for char in word:
if char == "*":
states[(current_state, char)] = end_state_id
break
elif (current_state, char) not in states:
state_id += 1
states[(current_state, char)] = state_id
current_state = states[(current_state, char)]
final_states = {current_state for word in valid_words for char in word if '*' in word}
return states, final_states
def decode(self, char: str) -> bool:
if (self.current_state, char) in self.states:
self.current_state = self.states[(self.current_state, char)]
return True
else:
return False
def is_valid(self) -> bool:
return self.current_state in self.final_states
if __name__ == "__main__":
# Example usage
candidate_words = ["hot*", "cold*", "hotel*"]
fsm_decoder = FSMDecoder(candidate_words)
for key, value in fsm_decoder.states.items():
print(f"{key} -> {value}")
print()
# Assume `*` is the end special token
decode_list = ["*"]
decode_list.extend(list(map(chr, range(65, 91))))
decode_list.extend(list(map(chr, range(97, 123))))
# Try to random decoding!
decode_word = ""
while True:
candidate_decode_char = random.choice(decode_list)
can_decode = fsm_decoder.decode(candidate_decode_char)
if can_decode:
print("FSMDecoder State:", fsm_decoder.states)
print("FSMDecoder Current State:", fsm_decoder.current_state)
print()
decode_word += candidate_decode_char
if candidate_decode_char == "*":
break
print("Final Decode Word:", decode_word)
Output:
(0, 'h') -> 1
(1, 'o') -> 2
(2, 't') -> 3
(3, '*') -> -1
(0, 'c') -> 4
(4, 'o') -> 5
(5, 'l') -> 6
(6, 'd') -> 7
(7, '*') -> -1
(3, 'e') -> 8
(8, 'l') -> 9
(9, '*') -> -1
FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 1
FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 2
FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: 3
FSMDecoder State: {(0, 'h'): 1, (1, 'o'): 2, (2, 't'): 3, (3, '*'): -1, (0, 'c'): 4, (4, 'o'): 5, (5, 'l'): 6, (6, 'd'): 7, (7, '*'): -1, (3, 'e'): 8, (8, 'l'): 9, (9, '*'): -1}
FSMDecoder Current State: -1
Final Decode Word: hot*
我們可以看到解碼的狀態順序為 1 => 2 => 3 => -1,所以結果為 hot。
當然我們可能會想:這樣的解碼為什麼非得建立有限狀態機?建立個字典樹不就好了?這句話理論上是對的,不過實際上 Outlines 所設想的情境,甚至允許使用者輸入正規表示法的 pattern 來限制模型生成!這樣一來,字典樹可就不是和了。
References
- https://en.wikipedia.org/wiki/Finite-state_machine
- outlines-dev/outlines: Structured Text Generation