Using a Finite-State Machine (FSM) and a Rollback Mechanism to Prevent an LLM from Generating Banned Words

Last Updated on 2024-10-28 by Clay

When building services on top of LLMs, have you ever worried about the model generating output you cannot control? My work is currently at the critical stage of wrapping up a project, and I have been using tools such as Outlines to constrain LLM decoding. They do let me make the model generate output that matches the patterns I have in mind, but then a colleague suddenly hit me with a soul-searching question: what if I want the model *not* to generate certain words?

What a great question! Our PM actually asked the same thing early on. At the time I had more pressing deadlines on hand, so without much thought I told him it would be difficult. Only now have I really started to think about it, and in theory it should be doable.

(Update 2024/10/28: I have now published the source code at https://github.com/ccs96307/llm-decode-filter-special-words; feedback is very welcome.)


Review of Existing Techniques

I started by searching for similar work:

Logit Biasing / Token Masking: some model providers support modifying the token probability distribution directly; for example, OpenAI offers a logit_bias parameter to suppress or boost the decoding of specific tokens.

Regex-based Constraints: as mentioned above, I previously explored structuring model output with tools such as Outlines, which use a finite-state machine to track the LLM's current state and constrain the range of tokens it may decode.

Direct Token Masking: the simplest and most direct approach, masking specified tokens so they can never be generated.

For me, though, none of these is quite what I want, because my restriction is based on "words" rather than "tokens" (for example, hamburger is one word but is composed of three tokens, ham-bur-ger). Say the profanity I want to block is "fuck you": I cannot simply ban "you", because then the ordinary "you" in normal conversation could no longer be decoded.
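
As a quick illustration (a minimal sketch; the exact split depends on the tokenizer, and the Gemma-2 tokenizer is only loaded in the implementation section below), a banned phrase usually spans more than one token:

from transformers import AutoTokenizer

# Illustrative only: any Hugging Face tokenizer behaves similarly,
# but the exact token split depends on the model's vocabulary.
tokenizer = AutoTokenizer.from_pretrained("./models/google--gemma-2-2b-it")

# A banned phrase is usually split into several tokens, so masking a single
# token ID (e.g., the ID for "you") would also break perfectly normal sentences.
print(tokenizer.tokenize("fuck you"))   # usually several pieces, e.g. ['fuck', '▁you']
print(tokenizer.tokenize("hamburger"))  # may also be split into sub-word pieces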

Of course, given my limited search skills, I may simply not have found a solution identical to mine. If you know of similar work, please share it with me so I have a chance to cite it.


Overview of the Approach

The approach I have in mind is to build an initial "banned-word finite-state machine" for each model, because every model's vocabulary is different.

First, build the finite-state machine:

  1. Define a list of banned words
  2. Using the model's vocabulary, enumerate every token that could be the first decoded token of a banned word
  3. Starting from each first matching token, keep enumerating the model's vocabulary, and so on, until every token combination that decodes into a banned word has been listed. Store the token_id sequences as a finite-state machine, and at the end of each word combination set the transition state to -1, meaning "successful match" (a small sketch follows this list)
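
To make the target data structure concrete, below is a minimal hand-written sketch of such an FSM for a single banned phrase tokenized as ['▁fuck', '▁you'], using the Gemma-2 token IDs 7935 and 692 that appear later in this post (the intermediate state number is arbitrary):

# state -> list of [token_id, next_state]; reaching -1 means "successful match"
fsm_sketch = {
    0: [[7935, 1]],   # '▁fuck' moves the machine from the start state to state 1
    1: [[692, -1]],   # '▁you' right after '▁fuck' completes the banned phrase
}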

Next comes the actual decoding stage:

  1. If the token the model decodes does not match any banned-word token for state 0 of the FSM, stay in state 0; once it matches one of the banned-word tokens, move to the corresponding next state
  2. If the next decoded token is not in the current state's banned-token list, return to state 0; otherwise keep transitioning. Once a "successful match" occurs and the state reaches -1, enter "roll-back mode"
  3. In roll-back mode, roll back to just before the first token of the matched banned-word path was decoded, forbid that first token at that position, and record it permanently in a {position_idx: [token_id]} mapping. Whenever the model decodes the token at position position_idx, it first checks this dictionary for token IDs that must not be decoded and masks any it finds (a small masking sketch follows this list)
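
As a minimal sketch of that masking step (the names here are illustrative; the full version appears in Step 3 below), applying the recorded bans to the logits before choosing the next token could look like this:

import torch

# Record built in roll-back mode: position in the generated sequence -> banned token IDs
masked_tokens_history = {7: {10724}}  # e.g., '▁listen' must never be decoded at step 7

def apply_position_mask(logits: torch.Tensor, position_idx: int) -> torch.Tensor:
    """Set the logits of tokens banned at this position to -inf so they can never be chosen."""
    for token_id in masked_tokens_history.get(position_idx, set()):
        logits[:, token_id] = -float("inf")
    return logits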

With the algorithm above, we can completely prevent the model from generating specific words.


Implementation

Below, we test with the Gemma-2 model architecture. Of course, I would like this decoding algorithm to support as many models as possible, but different models use different vocabularies, so I cannot guarantee that the current rules apply to every model unchanged.

First, get all the configuration ready:

from typing import List, Dict, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Model and Tokenizer
pretrained_model_name_or_path = "./models/google--gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
model.to(device)



Step 1. Define a list of banned words

Here, I name the list of banned words banned_words:

banned_words = ["talk", "listen", "fuck you"]



Step 2. Using the model's vocabulary, enumerate every generation route that can match a banned word

A few things are worth noting here: not every token can be matched against banned_words directly after decoding. Models such as Gemma-2 and GPT-2 have many tokens like ▁listen, where ▁ is a symbol marking a "boundary" or "start of a word", so it must be stripped before matching against banned_words.

In addition, our vocabulary also contains tokens made up of special characters. Enumerating decoding routes over all of them would make the computation grow exponentially, so I use the regular expression re.compile(r'^[\t\n\r\f\v\s]*$') to match such whitespace-only tokens and drop them.

import re

new_token_candidates = []
new_token_id_candidates = []

# Define a pattern to filter out whitespace-only tokens (e.g., runs of tabs or newlines)
invalid_token_pattern = re.compile(r'^[\t\n\r\f\v\s]*$')  # Matches tokens made up entirely of whitespace characters

# Initial matching of tokens that start with banned word prefixes
for token, token_id in tokenizer.vocab.items():
    token_stripped = token.strip()
    if len(token_stripped) == 0 or (len(token_stripped) == 1 and token_stripped == "▁"): 
        continue

    if token_stripped[0] == "▁":
        token_stripped = token_stripped[1:]

    for banned_word in banned_words:
        if banned_word.startswith(token_stripped):
            new_token_candidates.append([token])  # Store as a list for consistency
            new_token_id_candidates.append([token_id])

print("Filtered First Candidates:", new_token_candidates)
print("Filtered First IDs:", new_token_id_candidates)

# Initialize final routes for fully matched banned words
final_routes = []
final_id_routes = []
for new_token, new_token_id in zip(new_token_candidates, new_token_id_candidates):
    new_token_str = "".join(new_token).strip()

    if new_token_str in banned_words or (new_token_str[0] == "▁" and new_token_str[1:] in banned_words):
        final_routes.append(new_token)
        final_id_routes.append(new_token_id)

print("(beginning) Final Routes:", final_routes)
print("(beginning) Final ID Routes:", final_id_routes)

# Iteratively expand candidates using BFS-like approach
while new_token_candidates:
    curr_token_candidates = new_token_candidates
    curr_token_id_candidates = new_token_id_candidates
    new_token_candidates = []
    new_token_id_candidates = []

    # Iterate over vocabulary and expand each candidate
    for token, token_id in tokenizer.vocab.items():
        token_stripped = token.strip()
        if len(token_stripped) == 0 or (len(token_stripped) == 1 and token_stripped == "▁") or re.findall(r"\▁{2,100}", token_stripped):
            continue

        # Skip tokens that consist of excessive whitespace or control characters
        if invalid_token_pattern.match(token_stripped):
            continue

        for candidate_token, candidate_token_id in zip(curr_token_candidates, curr_token_id_candidates):
            curr_token = candidate_token + [token]
            curr_token_ids = candidate_token_id + [token_id]

            curr_token_str = "".join([token.replace("▁", " ") for token in curr_token]).strip()
            # if curr_token_str[0] == "▁":
            #     curr_token_str = curr_token_str[1:]

            # Check if the current token combination matches or is a prefix of any banned word
            for banned_word in banned_words:
                if curr_token_str == banned_word:
                    # Full match found, add to final routes
                    final_routes.append(curr_token)
                    final_id_routes.append(curr_token_ids)
                elif banned_word.startswith(curr_token_str):
                    # Partial match, keep expanding this candidate
                    new_token_candidates.append(curr_token)
                    new_token_id_candidates.append(curr_token_ids)

    print(final_routes)

print("Final Routes:", final_routes)
print("Final ID Routes:", final_id_routes)


At this point, our enumerated decoding routes can be represented as:

['listen'] [18998]
['▁listen'] [10724]
['▁talk'] [5063]
['talk'] [33085]
['▁tal', 'k'] [3412, 235273]
['tal', 'k'] [3559, 235273]
['li', 'sten'] [515, 5547]
['▁li', 'sten'] [702, 5547]
['ta', 'lk'] [516, 26159]
['▁ta', 'lk'] [3586, 26159]
['lis', 'ten'] [15063, 965]
['▁lis', 'ten'] [23966, 965]
['▁t', 'alk'] [474, 2071]
['t', 'alk'] [235251, 2071]
['▁liste', 'n'] [32165, 235254]
['liste', 'n'] [44003, 235254]
['▁list', 'en'] [1889, 479]
['list', 'en'] [1701, 479]
['l', 'isten'] [235257, 17071]
['▁l', 'isten'] [533, 17071]
['fuck', '▁you'] [34024, 692]
['▁fuck', '▁you'] [7935, 692]
['ta', 'l', 'k'] [516, 235257, 235273]
['▁ta', 'l', 'k'] [3586, 235257, 235273]
['▁t', 'al', 'k'] [474, 492, 235273]
['t', 'al', 'k'] [235251, 492, 235273]
['l', 'i', 'sten'] [235257, 235252, 5547]
['▁l', 'i', 'sten'] [533, 235252, 5547]
['▁t', 'a', 'lk'] [474, 235250, 26159]
['t', 'a', 'lk'] [235251, 235250, 26159]
['l', 'is', 'ten'] [235257, 502, 965]
['▁l', 'is', 'ten'] [533, 502, 965]
['li', 's', 'ten'] [515, 235256, 965]
['▁li', 's', 'ten'] [702, 235256, 965]
['fuck', '▁y', 'ou'] [34024, 597, 507]
['▁fuck', '▁y', 'ou'] [7935, 597, 507]
['l', 'iste', 'n'] [235257, 3671, 235254]
['▁l', 'iste', 'n'] [533, 3671, 235254]
['li', 'ste', 'n'] [515, 2855, 235254]
['▁li', 'ste', 'n'] [702, 2855, 235254]
['lis', 'te', 'n'] [15063, 488, 235254]
['▁lis', 'te', 'n'] [23966, 488, 235254]
['▁list', 'e', 'n'] [1889, 235249, 235254]
['list', 'e', 'n'] [1701, 235249, 235254]
['l', 'ist', 'en'] [235257, 694, 479]
['▁l', 'ist', 'en'] [533, 694, 479]
['li', 'st', 'en'] [515, 490, 479]
['▁li', 'st', 'en'] [702, 490, 479]
['lis', 't', 'en'] [15063, 235251, 479]
['▁lis', 't', 'en'] [23966, 235251, 479]
['▁fu', 'ck', '▁you'] [4936, 623, 692]
['fu', 'ck', '▁you'] [12819, 623, 692]
['▁fuc', 'k', '▁you'] [79433, 235273, 692]
['▁f', 'uck', '▁you'] [517, 1870, 692]
['f', 'uck', '▁you'] [235266, 1870, 692]
['fuck', '▁yo', 'u'] [34024, 10931, 235261]
['▁fuck', '▁yo', 'u'] [7935, 10931, 235261]
['▁t', 'a', 'l', 'k'] [474, 235250, 235257, 235273]
['t', 'a', 'l', 'k'] [235251, 235250, 235257, 235273]
['l', 'i', 's', 'ten'] [235257, 235252, 235256, 965]
['▁l', 'i', 's', 'ten'] [533, 235252, 235256, 965]
['▁fu', 'ck', '▁y', 'ou'] [4936, 623, 597, 507]
['fu', 'ck', '▁y', 'ou'] [12819, 623, 597, 507]
['▁fuc', 'k', '▁y', 'ou'] [79433, 235273, 597, 507]
['▁f', 'uck', '▁y', 'ou'] [517, 1870, 597, 507]
['f', 'uck', '▁y', 'ou'] [235266, 1870, 597, 507]
['l', 'i', 'ste', 'n'] [235257, 235252, 2855, 235254]
['▁l', 'i', 'ste', 'n'] [533, 235252, 2855, 235254]
['l', 'is', 'te', 'n'] [235257, 502, 488, 235254]
['▁l', 'is', 'te', 'n'] [533, 502, 488, 235254]
['li', 's', 'te', 'n'] [515, 235256, 488, 235254]
['▁li', 's', 'te', 'n'] [702, 235256, 488, 235254]
['l', 'ist', 'e', 'n'] [235257, 694, 235249, 235254]
['▁l', 'ist', 'e', 'n'] [533, 694, 235249, 235254]
['li', 'st', 'e', 'n'] [515, 490, 235249, 235254]
['▁li', 'st', 'e', 'n'] [702, 490, 235249, 235254]
['lis', 't', 'e', 'n'] [15063, 235251, 235249, 235254]
['▁lis', 't', 'e', 'n'] [23966, 235251, 235249, 235254]
['l', 'i', 'st', 'en'] [235257, 235252, 490, 479]
['▁l', 'i', 'st', 'en'] [533, 235252, 490, 479]
['l', 'is', 't', 'en'] [235257, 502, 235251, 479]
['▁l', 'is', 't', 'en'] [533, 502, 235251, 479]
['li', 's', 't', 'en'] [515, 235256, 235251, 479]
['▁li', 's', 't', 'en'] [702, 235256, 235251, 479]
['▁f', 'u', 'ck', '▁you'] [517, 235261, 623, 692]
['f', 'u', 'ck', '▁you'] [235266, 235261, 623, 692]
['▁fu', 'c', 'k', '▁you'] [4936, 235260, 235273, 692]
['fu', 'c', 'k', '▁you'] [12819, 235260, 235273, 692]
['▁f', 'uc', 'k', '▁you'] [517, 1669, 235273, 692]
['f', 'uc', 'k', '▁you'] [235266, 1669, 235273, 692]
['fuck', '▁y', 'o', 'u'] [34024, 597, 235253, 235261]
['▁fuck', '▁y', 'o', 'u'] [7935, 597, 235253, 235261]
['▁fu', 'ck', '▁yo', 'u'] [4936, 623, 10931, 235261]
['fu', 'ck', '▁yo', 'u'] [12819, 623, 10931, 235261]
['▁fuc', 'k', '▁yo', 'u'] [79433, 235273, 10931, 235261]
['▁f', 'uck', '▁yo', 'u'] [517, 1870, 10931, 235261]
['f', 'uck', '▁yo', 'u'] [235266, 1870, 10931, 235261]
['▁f', 'u', 'ck', '▁y', 'ou'] [517, 235261, 623, 597, 507]
['f', 'u', 'ck', '▁y', 'ou'] [235266, 235261, 623, 597, 507]
['▁fu', 'c', 'k', '▁y', 'ou'] [4936, 235260, 235273, 597, 507]
['fu', 'c', 'k', '▁y', 'ou'] [12819, 235260, 235273, 597, 507]
['▁f', 'uc', 'k', '▁y', 'ou'] [517, 1669, 235273, 597, 507]
['f', 'uc', 'k', '▁y', 'ou'] [235266, 1669, 235273, 597, 507]
['l', 'i', 's', 'te', 'n'] [235257, 235252, 235256, 488, 235254]
['▁l', 'i', 's', 'te', 'n'] [533, 235252, 235256, 488, 235254]
['l', 'i', 'st', 'e', 'n'] [235257, 235252, 490, 235249, 235254]
['▁l', 'i', 'st', 'e', 'n'] [533, 235252, 490, 235249, 235254]
['l', 'is', 't', 'e', 'n'] [235257, 502, 235251, 235249, 235254]
['▁l', 'is', 't', 'e', 'n'] [533, 502, 235251, 235249, 235254]
['li', 's', 't', 'e', 'n'] [515, 235256, 235251, 235249, 235254]
['▁li', 's', 't', 'e', 'n'] [702, 235256, 235251, 235249, 235254]
['l', 'i', 's', 't', 'en'] [235257, 235252, 235256, 235251, 479]
['▁l', 'i', 's', 't', 'en'] [533, 235252, 235256, 235251, 479]
['▁f', 'u', 'c', 'k', '▁you'] [517, 235261, 235260, 235273, 692]
['f', 'u', 'c', 'k', '▁you'] [235266, 235261, 235260, 235273, 692]
['▁fu', 'ck', '▁y', 'o', 'u'] [4936, 623, 597, 235253, 235261]
['fu', 'ck', '▁y', 'o', 'u'] [12819, 623, 597, 235253, 235261]
['▁fuc', 'k', '▁y', 'o', 'u'] [79433, 235273, 597, 235253, 235261]
['▁f', 'uck', '▁y', 'o', 'u'] [517, 1870, 597, 235253, 235261]
['f', 'uck', '▁y', 'o', 'u'] [235266, 1870, 597, 235253, 235261]
['▁f', 'u', 'ck', '▁yo', 'u'] [517, 235261, 623, 10931, 235261]
['f', 'u', 'ck', '▁yo', 'u'] [235266, 235261, 623, 10931, 235261]
['▁fu', 'c', 'k', '▁yo', 'u'] [4936, 235260, 235273, 10931, 235261]
['fu', 'c', 'k', '▁yo', 'u'] [12819, 235260, 235273, 10931, 235261]
['▁f', 'uc', 'k', '▁yo', 'u'] [517, 1669, 235273, 10931, 235261]
['f', 'uc', 'k', '▁yo', 'u'] [235266, 1669, 235273, 10931, 235261]
['▁f', 'u', 'c', 'k', '▁y', 'ou'] [517, 235261, 235260, 235273, 597, 507]
['f', 'u', 'c', 'k', '▁y', 'ou'] [235266, 235261, 235260, 235273, 597, 507]
['l', 'i', 's', 't', 'e', 'n'] [235257, 235252, 235256, 235251, 235249, 235254]
['▁l', 'i', 's', 't', 'e', 'n'] [533, 235252, 235256, 235251, 235249, 235254]
['▁f', 'u', 'ck', '▁y', 'o', 'u'] [517, 235261, 623, 597, 235253, 235261]
['f', 'u', 'ck', '▁y', 'o', 'u'] [235266, 235261, 623, 597, 235253, 235261]
['▁fu', 'c', 'k', '▁y', 'o', 'u'] [4936, 235260, 235273, 597, 235253, 235261]
['fu', 'c', 'k', '▁y', 'o', 'u'] [12819, 235260, 235273, 597, 235253, 235261]
['▁f', 'uc', 'k', '▁y', 'o', 'u'] [517, 1669, 235273, 597, 235253, 235261]
['f', 'uc', 'k', '▁y', 'o', 'u'] [235266, 1669, 235273, 597, 235253, 235261]
['▁f', 'u', 'c', 'k', '▁yo', 'u'] [517, 235261, 235260, 235273, 10931, 235261]
['f', 'u', 'c', 'k', '▁yo', 'u'] [235266, 235261, 235260, 235273, 10931, 235261]
['▁f', 'u', 'c', 'k', '▁y', 'o', 'u'] [517, 235261, 235260, 235273, 597, 235253, 235261]
['f', 'u', 'c', 'k', '▁y', 'o', 'u'] [235266, 235261, 235260, 235273, 597, 235253, 235261]


With these enumerated routes, we can implement a finite-state machine that tells us, at any moment, which token IDs are not allowed in the current state. For the implementation details of the finite-state machine, you can refer to my earlier post on implementing constrained decoding for large language models with a finite-state machine, though that was also a fairly simple implementation.

class FSMProcessor:
    def __init__(self, special_token_ids_list: List[List[int]], end_state: int = -1) -> None:
        self.end_state = end_state
        self.next_state = 1
        self.curr_state = 0
        self.fsm = {}
        self.special_words = []

        # Track partial matches
        self.partial_match_state = None
        self.partial_tokens = []

        self.update_group(special_token_ids_list)

    def update(self, special_token_ids: List[int]) -> None:
        curr_state = 0

        for idx, special_token_id in enumerate(special_token_ids):
            if curr_state not in self.fsm:
                self.fsm[curr_state] = []

            state2id = [items[0] for items in self.fsm[curr_state]]
            if special_token_id not in state2id:
                if idx == len(special_token_ids) - 1:
                    self.fsm[curr_state].append([special_token_id, self.end_state])
                else:
                    self.fsm[curr_state].append([special_token_id, self.next_state])
                    curr_state = self.next_state
                    self.next_state += 1
            else:
                for fsm_idx in range(len(self.fsm[curr_state])):
                    if special_token_id == self.fsm[curr_state][fsm_idx][0] and idx == len(special_token_ids) - 1:
                        self.fsm[curr_state][fsm_idx][1] = self.end_state
                        break
                    elif special_token_id == self.fsm[curr_state][fsm_idx][0]:
                        curr_state = self.fsm[curr_state][fsm_idx][1]
                        break

    def update_group(self, special_token_ids_list: List[List[int]]) -> None:
        for special_token_ids in special_token_ids_list:
            self.update(special_token_ids=special_token_ids)

    def get_fsm_data(self) -> Dict[int, List[List[int]]]:
        return self.fsm
    
    def detect(self, token: int) -> bool:
        """
        Detect if the current token leads to a sensitive sequence.
        Updates the current state and returns True if it reaches the end state.
        """
        if self.curr_state in self.fsm:
            for transition in self.fsm[self.curr_state]:
                if transition[0] == token:
                    self.curr_state = transition[1]

                    # If the current state reaches the end state
                    return self.curr_state == self.end_state
        
        # If the token does not match, reset the current state
        self.curr_state = 0
        return False


fsm_processor = FSMProcessor(special_token_ids_list=final_id_routes)
fsm_processor.get_fsm_data()


Output:

{0: [[18998, -1],
[10724, -1],
[5063, -1],
[33085, -1],
[3412, 1],
[3559, 2],
[515, 3],
[702, 4],
[516, 5],
[3586, 6],
[15063, 7],
[23966, 8],
[474, 9],
[235251, 10],
[32165, 11],
[44003, 12],
[1889, 13],
[1701, 14],
[235257, 15],
[533, 16],
[34024, 17],
[7935, 18],
[4936, 47],
[12819, 49],
[79433, 51],
[517, 53],
[235266, 55]],
1: [[235273, -1]],
2: [[235273, -1]],
3: [[5547, -1], [235256, 29], [2855, 35], [490, 43]],
4: [[5547, -1], [235256, 30], [2855, 36], [490, 44]],
5: [[26159, -1], [235257, 19]],
6: [[26159, -1], [235257, 20]],
7: [[965, -1], [488, 37], [235251, 45]],
8: [[965, -1], [488, 38], [235251, 46]],
9: [[2071, -1], [492, 21], [235250, 25]],
10: [[2071, -1], [492, 22], [235250, 26]],
11: [[235254, -1]],
12: [[235254, -1]],
13: [[479, -1], [235249, 39]],
14: [[479, -1], [235249, 40]],
15: [[17071, -1], [235252, 23], [502, 27], [3671, 33], [694, 41]],
16: [[17071, -1], [235252, 24], [502, 28], [3671, 34], [694, 42]],
17: [[692, -1], [597, 31], [10931, 57]],
18: [[692, -1], [597, 32], [10931, 58]],
19: [[235273, -1]],
20: [[235273, -1]],
21: [[235273, -1]],
22: [[235273, -1]],
23: [[5547, -1], [235256, 61], [2855, 68], [490, 80]],
24: [[5547, -1], [235256, 62], [2855, 69], [490, 81]],
25: [[26159, -1], [235257, 59]],
26: [[26159, -1], [235257, 60]],
27: [[965, -1], [488, 70], [235251, 82]],
28: [[965, -1], [488, 71], [235251, 83]],
29: [[965, -1], [488, 72], [235251, 84]],
30: [[965, -1], [488, 73], [235251, 85]],
31: [[507, -1], [235253, 98]],
32: [[507, -1], [235253, 99]],
33: [[235254, -1]],
34: [[235254, -1]],
35: [[235254, -1]],
36: [[235254, -1]],
37: [[235254, -1]],
38: [[235254, -1]],
39: [[235254, -1]],
40: [[235254, -1]],
41: [[479, -1], [235249, 74]],
42: [[479, -1], [235249, 75]],
43: [[479, -1], [235249, 76]],
44: [[479, -1], [235249, 77]],
45: [[479, -1], [235249, 78]],
46: [[479, -1], [235249, 79]],
47: [[623, 48], [235260, 90]],
48: [[692, -1], [597, 63], [10931, 100]],
49: [[623, 50], [235260, 92]],
50: [[692, -1], [597, 64], [10931, 101]],
51: [[235273, 52]],
52: [[692, -1], [597, 65], [10931, 102]],
53: [[1870, 54], [235261, 86], [1669, 94]],
54: [[692, -1], [597, 66], [10931, 103]],
55: [[1870, 56], [235261, 88], [1669, 96]],
56: [[692, -1], [597, 67], [10931, 104]],
57: [[235261, -1]],
58: [[235261, -1]],
59: [[235273, -1]],
60: [[235273, -1]],
61: [[965, -1], [488, 111], [235251, 119]],
62: [[965, -1], [488, 112], [235251, 120]],
63: [[507, -1], [235253, 125]],
64: [[507, -1], [235253, 126]],
65: [[507, -1], [235253, 127]],
66: [[507, -1], [235253, 128]],
67: [[507, -1], [235253, 129]],
68: [[235254, -1]],
69: [[235254, -1]],
70: [[235254, -1]],
71: [[235254, -1]],
72: [[235254, -1]],
73: [[235254, -1]],
74: [[235254, -1]],
75: [[235254, -1]],
76: [[235254, -1]],
77: [[235254, -1]],
78: [[235254, -1]],
79: [[235254, -1]],
80: [[479, -1], [235249, 113]],
81: [[479, -1], [235249, 114]],
82: [[479, -1], [235249, 115]],
83: [[479, -1], [235249, 116]],
84: [[479, -1], [235249, 117]],
85: [[479, -1], [235249, 118]],
86: [[623, 87], [235260, 121]],
87: [[692, -1], [597, 105], [10931, 130]],
88: [[623, 89], [235260, 123]],
89: [[692, -1], [597, 106], [10931, 131]],
90: [[235273, 91]],
91: [[692, -1], [597, 107], [10931, 132]],
92: [[235273, 93]],
93: [[692, -1], [597, 108], [10931, 133]],
94: [[235273, 95]],
95: [[692, -1], [597, 109], [10931, 134]],
96: [[235273, 97]],
97: [[692, -1], [597, 110], [10931, 135]],
98: [[235261, -1]],
99: [[235261, -1]],
100: [[235261, -1]],
101: [[235261, -1]],
102: [[235261, -1]],
103: [[235261, -1]],
104: [[235261, -1]],
105: [[507, -1], [235253, 140]],
106: [[507, -1], [235253, 141]],
107: [[507, -1], [235253, 142]],
108: [[507, -1], [235253, 143]],
109: [[507, -1], [235253, 144]],
110: [[507, -1], [235253, 145]],
111: [[235254, -1]],
112: [[235254, -1]],
113: [[235254, -1]],
114: [[235254, -1]],
115: [[235254, -1]],
116: [[235254, -1]],
117: [[235254, -1]],
118: [[235254, -1]],
119: [[479, -1], [235249, 138]],
120: [[479, -1], [235249, 139]],
121: [[235273, 122]],
122: [[692, -1], [597, 136], [10931, 146]],
123: [[235273, 124]],
124: [[692, -1], [597, 137], [10931, 147]],
125: [[235261, -1]],
126: [[235261, -1]],
127: [[235261, -1]],
128: [[235261, -1]],
129: [[235261, -1]],
130: [[235261, -1]],
131: [[235261, -1]],
132: [[235261, -1]],
133: [[235261, -1]],
134: [[235261, -1]],
135: [[235261, -1]],
136: [[507, -1], [235253, 148]],
137: [[507, -1], [235253, 149]],
138: [[235254, -1]],
139: [[235254, -1]],
140: [[235261, -1]],
141: [[235261, -1]],
142: [[235261, -1]],
143: [[235261, -1]],
144: [[235261, -1]],
145: [[235261, -1]],
146: [[235261, -1]],
147: [[235261, -1]],
148: [[235261, -1]],
149: [[235261, -1]]}
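
As a quick sanity check of how this FSM is consumed at decode time (a small usage sketch based on the ['▁fuck', '▁you'] route with token IDs 7935 and 692 listed above), detect() walks the transitions and only returns True when the end state -1 is reached. Note that it mutates curr_state, so reset it before the real decoding loop:

print(fsm_processor.detect(7935))  # False: partial match, the FSM leaves state 0
print(fsm_processor.detect(692))   # True: '▁fuck' followed by '▁you' completes a banned word

# detect() updated curr_state along the way, so reset it before actually decoding
fsm_processor.curr_state = 0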



Step 3. Normal decoding vs. decoding with the FSM + roll-back mechanism

Let's first look at the normal decoding case, using greedy search:

def custom_generate(input_ids: torch.Tensor, max_length: int = 50) -> str:
    for _ in range(max_length):
        # Generate new tokens
        outputs = model(input_ids.to(device), return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]
        
        new_generated_token_id = torch.argmax(logits, dim=-1)

        if new_generated_token_id == tokenizer.eos_token_id:
            break

        input_ids = torch.cat((input_ids, new_generated_token_id.unsqueeze(0)), dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)


input_text = "Can we talk?"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
print(custom_generate(input_ids=inputs.input_ids))


Output:

<bos>Can we talk?

I'm here to listen and help in any way I can.

What's on your mind?
<end_of_turn>


We can see the model answers in a perfectly conventional way. Now let's add our banned words to constrain the model's decoding; the banned words are: ["talk", "listen", "fuck you"]

def custom_generate_with_fsm_filter(
    input_ids: torch.Tensor,
    fsm_processor: FSMProcessor,
    max_length: int = 20,
) -> str:
    # Historical mask list used to record masked tokens at each decoding step
    masked_tokens_history = {}
    past_key_values = None
    steps = 0

    while steps < max_length:
        steps += 1

        # Generate new token with kv cache
        outputs = model(input_ids, past_key_values=past_key_values, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Update kv cache
        past_key_values = outputs.past_key_values

        # Check if there are already masked tokens at the current step
        if steps in masked_tokens_history:
            for masked_token_id in masked_tokens_history[steps]:
                logits[:, masked_token_id] = -float("inf")
        else:
            masked_tokens_history[steps] = set()

        # Decode the generated token
        generated_token_id = torch.argmax(logits, dim=-1).item()
        combined_ids = torch.cat((input_ids, torch.tensor([[generated_token_id]], device=input_ids.device)), dim=-1)

        # Check FSM for sensitive sequences
        if fsm_processor.detect(generated_token_id):
            # Detected a sensitive sequence, initiate rollback
            rollback_length = fsm_processor.partial_match_state + 1 if fsm_processor.partial_match_state is not None else 1
            steps = steps - rollback_length + 1
            rollbacks_ids = combined_ids[:, :-rollback_length]
            input_ids = rollbacks_ids
            print(f"Rollback detected. Rolling back from step {steps + rollback_length} to step {steps}")

            # Reset FSM state
            fsm_processor.curr_state = 0
            fsm_processor.partial_match_state = None

            # Reset past_key_values when rolling back
            past_key_values = None

            # Recalculate logits based on rolled-back sequence
            outputs = model(input_ids, return_dict=True, use_cache=True)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

            # Mask the first token of the sensitive sequence
            first_token_id = generated_token_id
            print(f"Masking token id: {first_token_id}, Masking token: {tokenizer.decode(first_token_id)}")

            # Update the historical mask list to record the token at this step
            masked_tokens_history[steps].add(first_token_id)

            for masked_token_id in masked_tokens_history[steps]:
                logits[:, masked_token_id] = -float("inf")

            # Generate the token again after masking
            generated_token_id = torch.argmax(logits, dim=-1).item()

        # Update input_ids with the generated token
        input_ids = torch.cat((input_ids, torch.tensor([[generated_token_id]], device=input_ids.device)), dim=1)

        print(f"Step {steps}: ID: {generated_token_id} Generated token: {tokenizer.decode(generated_token_id)}")

        if generated_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)


response = custom_generate_with_fsm_filter(
    input_ids=inputs.input_ids,
    fsm_processor=fsm_processor,
    max_length=50,
)
print(response)


Output:

Step 1: ID: 109 Generated token: 


Step 2: ID: 235285 Generated token: I
Step 3: ID: 235303 Generated token: '
Step 4: ID: 235262 Generated token: m
Step 5: ID: 1517 Generated token: here
Step 6: ID: 577 Generated token: to
Rollback detected. Rolling back from step 8 to step 7
Masking token id: 10724, Masking token: listen
Step 7: ID: 1707 Generated token: help
Step 8: ID: 692 Generated token: you
Step 9: ID: 675 Generated token: with
Step 10: ID: 9550 Generated token: whatever
Step 11: ID: 692 Generated token: you
Step 12: ID: 1476 Generated token: need
Step 13: ID: 235265 Generated token: .
Step 14: ID: 235248 Generated token:
Step 15: ID: 109 Generated token:


Step 16: ID: 5958 Generated token: Please
Step 17: ID: 3337 Generated token: tell
Step 18: ID: 682 Generated token: me
Step 19: ID: 1212 Generated token: what
Step 20: ID: 235303 Generated token: '
Step 21: ID: 235256 Generated token: s
Step 22: ID: 611 Generated token: on
Step 23: ID: 861 Generated token: your
Step 24: ID: 3403 Generated token: mind
Step 25: ID: 235265 Generated token: .
Step 26: ID: 44416 Generated token: 😊
Step 27: ID: 108 Generated token:

Step 28: ID: 107 Generated token: <end_of_turn>
Step 29: ID: 1 Generated token: <eos>

<bos>Can we talk?

I'm here to help you with whatever you need.

Please tell me what's on your mind. 😊
<end_of_turn><eos>


We can see that where the model originally decoded ▁listen, our FSM now catches it; the roll-back mechanism rolls back to just before that decoding step, and the token with ID 10724 is forbidden from being decoded.

And that is it: a simple algorithm that combines a finite-state machine with a roll-back mechanism to constrain decoding so that banned words are never generated. There are surely still some exceptions and edge cases left, and I will keep refining it when I find the time.

Thank you to everyone who read this far, and please feel free to share any feedback!

