
[PyTorch] Notes on Implementing the BERT Architecture

Preface

My advisor used to tell me not to merely use other people's packages: you only really get a feel for a technique once you have implemented it yourself. Back then I did not have much time to implement every technique I was interested in; just getting my thesis out the door took everything I had. But his advice still comes back to me, and I finally could not resist sitting down to implement BERT, the classic encoder-only Transformer model.

Another reason is that a while ago I discussed with a colleague what the PAD special token actually means during the forward pass of a Transformer, as well as the details of how the multi-head attention mechanism extracts features. That conversation reminded me that the Transformer still has many variants and details I am not familiar enough with, and that revisiting my understanding of this classic architecture is both meaningful and necessary.

So over roughly three weeks, working on and off, I implemented the BERT model in PyTorch without looking at how the HuggingFace transformers library implements it; I only compared against its final outputs and consulted various articles and diagrams online.

Of course, I have to be honest: my BERT implementation only matches the output of the BERT model in the HuggingFace transformers library up to the fifth decimal place, with discrepancies appearing from the sixth decimal place onward. Parts of my implementation also end up nearly identical to modeling_bert.py in transformers, likely because some of the references I consulted had themselves borrowed code from it.

Below is a brief walkthrough of my implementation. If you spot any mistakes, please do not hesitate to point them out. Thank you! Finally, my source code is open-sourced on GitHub; feel free to take a look.


A Brief Introduction to BERT

BERT was first proposed by Google in the 2018 paper "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding". After its release, the leaderboards of most mainstream NLP benchmarks were swept almost overnight, simply because BERT performed so well.

This set off a wave of BERT mania: feature extractors and classifiers in nearly every domain got their own fine-tuned BERT, producing an enormous number of BERT variants, much like the currently popular GPT models (as of February 2024).

BERT follows the classic two-stage pretraining + fine-tuning recipe. In the first stage, two self-supervised objectives, predicting masked MASK tokens and Next Sentence Prediction (NSP), are used to learn features from large amounts of text. Once pretraining is done, BERT is already quite capable of extracting features, and what remains is fine-tuning it on the downstream task the developer cares about.

Precisely because it has been pretrained on massive amounts of data, BERT does not need nearly as much data during downstream fine-tuning; it adapts quickly and converges to a solid level of performance with relatively little data.

By today's standards, BERT's architecture is also very simple: it is just a stack of self-attention and feed-forward neural network blocks.

Let's now implement the actual BERT architecture. The goal is to be able to load pretrained weights from HuggingFace and match the output of the BertModel implemented in transformers up to a reasonable precision.


BERT Implementation

While implementing the individual modules, you can keep cross-checking against the architecture diagram I drew above.

Importing the required packages

The current implementation defines two classes, BertConfig and BertModel, and does not implement BertTokenizer. In the test code that builds the model, I therefore use the BertTokenizer from the transformers package directly as the tokenizer.

from typing import OrderedDict, Optional, List, Tuple

from dataclasses import dataclass
import json
import math

import torch
from transformers import AutoModel, BertTokenizer
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file



BertConfig

The attributes of the config class below are modeled on bert-tiny and bert-base, and the defaults are meant to be compatible with the config format used by from_pretrained() in transformers. If a BERT model on the HuggingFace Hub has extra parameters in its config file, however, initializing this dataclass may fail.

Since this is a minimal implementation for practice rather than production code, I did not add any error handling for that case.

@dataclass
class BertConfig:
    architectures: Optional[List[str]] = None
    attention_probs_dropout_prob: float = 0.1
    gradient_checkpointing: Optional[bool] = False
    classifier_dropout: Optional[float] = None
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    hidden_size: int = 128
    initializer_range: float = 0.02
    intermediate_size: int = 512
    layer_norm_eps: float = 1e-12
    max_position_embeddings: int = 512
    model_type: str = "bert"
    num_attention_heads: int = 2
    num_hidden_layers: int = 2
    pad_token_id: int = 0
    position_embedding_type: str = "absolute"
    transformers_version: str = "4.36.2"
    type_vocab_size: int = 2
    use_cache: bool = True
    vocab_size: int = 30522

    @staticmethod
    def from_pretrained_model_or_path(pretrained_model_name_or_path: str) -> "BertConfig":
        resolved_archive_file = cached_file(
            path_or_repo_id=pretrained_model_name_or_path,
            filename=CONFIG_NAME,
            _raise_exceptions_for_missing_entries=False,
        )
        with open(resolved_archive_file) as config_file:
            config_content = json.load(config_file)
        return BertConfig(**config_content)
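
As a quick usage check (an illustrative snippet that assumes access to the HuggingFace Hub), loading the config of prajjwal1/bert-tiny, the model used in the test script later, should give back the same values as the defaults above:

config = BertConfig.from_pretrained_model_or_path("prajjwal1/bert-tiny")
print(config.hidden_size, config.num_hidden_layers, config.num_attention_heads)  # 128 2 2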



Output dataclasses

This part does not actually contribute to the BERT architecture itself; it is only there to align the output format with that of the BertModel class in transformers.

BaseModelOutputWithPastAndCrossAttentions is the return format of the intermediate encoder pass, while BaseModelOutputWithPoolingAndCrossAttentions is the output format of the final model.

@dataclass
class BaseModelOutputWithPastAndCrossAttentions:
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions:
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None



BertEmbeddings

We finally arrive at the first layer of the BERT architecture. Once our input enters the model, the first neural network layer it meets is the embedding layer.

The job of the embedding layer is to take specific IDs as input and return the corresponding vectors (embeddings) from a lookup table. These embeddings carry meaning, and since they are vectors, they are by definition numerical data.

Suppose we want the vectors for the sentence 「今天天氣真好」 ("the weather is really nice today"). We would first tokenize it and convert it into the corresponding ID sequence (this is only an illustration; in practice a single Chinese character is usually split into 1–3 tokens):

["今", "天", "天", "氣", "真", "好"] =>
[1, 2, 2, 3, 4 ,4]


Then, by a simple table lookup, each ID is mapped to its corresponding vector. That is all the embedding layer does.
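
As a minimal sketch of this lookup (the vocabulary size and embedding dimension are toy numbers chosen only for illustration; the IDs match the example above):

import torch

# A toy embedding table: 10 possible IDs, each mapped to a 4-dimensional vector
embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)

token_ids = torch.tensor([[1, 2, 2, 3, 4, 5]])  # (batch_size=1, seq_len=6)
vectors = embedding(token_ids)                  # table lookup -> shape (1, 6, 4)
print(vectors.shape)  # torch.Size([1, 6, 4])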

In BERT, besides the token ID sequence, there are two more inputs: position information (0 up to seq_len - 1) and the token type.

Position information is encoded by position_embeddings, which in BERT is a trainable embedding layer (learned absolute position embeddings).

The token type is a piece of information that often gets overlooked; it comes from BERT's Next Sentence Prediction (NSP) pretraining task.

In NSP, BERT has to decide whether the sentence whose tokens are all labeled 0 and the sentence whose tokens are all labeled 1 actually follow each other, which is why the different token types get their own embedding vectors.

In cross-encoder tasks we again need this embedding layer to tell the two sentences apart; for ordinary single-sentence inputs, simply setting everything to 0 is enough.
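
For instance, encoding a sentence pair with the transformers BertTokenizer (the same tokenizer used in the test script at the end; the exact IDs below are illustrative) produces token_type_ids like these, while a single sentence would get all zeros:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-tiny")
encoded = tokenizer("Hello", "Nice to meet you too")
print(encoded["token_type_ids"])
# [CLS] and the first sentence (plus its [SEP]) belong to segment 0,
# the second sentence (plus its trailing [SEP]) to segment 1,
# e.g. [0, 0, 0, 1, 1, 1, 1, 1, 1]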

Finally, the three sets of embeddings are summed, then passed through layer normalization (LayerNorm) and dropout before being output.

class BertEmbeddings(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = torch.nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = torch.nn.LayerNorm((config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids, dtype=torch.long, device=input_ids.device)

        word_embedding = self.word_embeddings(input_ids)
        token_type_embedding = self.token_type_embeddings(token_type_ids)
        position_embedding = self.position_embeddings(position_ids)

        # Combine all embeddings
        embedding = word_embedding + position_embedding + token_type_embedding

        embedding = self.LayerNorm(embedding)
        embedding = self.dropout(embedding)
        return embedding



BERT Attention

Self-attention is the most interesting part of the Transformer architecture. The embeddings computed above are passed through the three Q, K, and V linear layers to produce the Query, Key, and Value for each token.

Query, Key, and Value originally have shape (batch_size, seq_len, hidden_size). To build the multi-head attention mechanism (MHA), where different heads learn to extract different features, we split hidden_size into head_num parts and reshape the tensors in this step to (batch_size, head_num, seq_len, hidden_size / head_num).

The rest is straightforward. The Query is multiplied with the transposed Key, and the result is divided by the square root of the per-head dimension (the sqrt(d) in the usual formula; head_size in the code below). The positions flagged by attention_mask are then masked out (these are usually PAD tokens added to make sentences in a batch the same length, which we do not want to take part in the computation): a very large negative value is added to their scores, so that after Softmax their attention weights are effectively 0.

Next, the attention weights are multiplied with the Value, the result is reshaped from (batch_size, head_num, seq_len, hidden_size / head_num) back to (batch_size, seq_len, hidden_size), and it goes through a linear layer, a residual connection, and LayerNorm. That completes the self-attention computation.
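
The scaled dot-product computation described above lives in the BertSelfAttention module, reproduced here from the complete listing at the end of the post:

class BertSelfAttention(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.head_size * self.num_attention_heads

        self.query = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.key = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.value = torch.nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, seq_len, hidden_size) -> (batch_size, head_num, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        q = self.transpose_for_scores(self.query(inputs))
        k = self.transpose_for_scores(self.key(inputs))
        v = self.transpose_for_scores(self.value(inputs))

        # Attention scores: scaled dot product between Q and K
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.head_size)

        # Mask out padding positions by adding a large negative value to their scores
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :].to(dtype=inputs.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(inputs.dtype).min
            attention_scores = attention_scores + extended_attention_mask

        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        attention_outputs = torch.matmul(attention_probs, v)

        # Merge the heads back: (batch_size, seq_len, all_head_size)
        attention_outputs = attention_outputs.permute(0, 2, 1, 3).contiguous()
        attention_outputs = attention_outputs.view(*attention_outputs.shape[:2], -1)
        return attention_outputs

The residual connection and LayerNorm that follow the attention output live in BertSelfOutput, and BertAttention ties the two modules together: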

class BertSelfOutput(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.dense(hidden_states)
        outputs = self.dropout(outputs)
        outputs = self.LayerNorm(outputs + inputs)
        return outputs


class BertAttention(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.self = BertSelfAttention(config=config)
        self.output = BertSelfOutput(config=config)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        hidden_states = self.self(inputs, attention_mask=attention_mask)
        outputs = self.output(hidden_states=hidden_states, inputs=inputs)
        return outputs



BERT Feed-Forward Neural Network

After the self-attention output has been computed, we move on to the feed-forward network. Keep in mind that the self-attention and feed-forward blocks are stacked several times according to the BertConfig settings, not applied just once; in bert-tiny, for example, these two modules are stacked 2 times (num_hidden_layers = 2).

class BertIntermediate(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
        )
        self.intermediate_act_fn = torch.nn.GELU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.dense(inputs)
        outputs = self.intermediate_act_fn(outputs)
        return outputs
    
class BertOutput(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.intermediate_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
    
    def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        outputs = self.LayerNorm(hidden_states + inputs)
        return outputs



BertLayer

The self-attention and feed-forward parts are assembled into a BertLayer, which will be stacked several times inside the BertEncoder class below.

class BertLayer(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.attention = BertAttention(config=config)
        self.intermediate = BertIntermediate(config=config)
        self.output = BertOutput(config=config)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        attention_outputs = self.attention(inputs=inputs, attention_mask=attention_mask)
        intermediate_outputs = self.intermediate(inputs=attention_outputs)
        outputs = self.output(hidden_states=intermediate_outputs, inputs=attention_outputs)
        return outputs



BertEncoder

In the BertEncoder, the BertLayer blocks built from self-attention and the feed-forward network are stacked num_hidden_layers times; the final hidden state is then wrapped in the intermediate output dataclass defined earlier.

class BertEncoder(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.layer = torch.nn.ModuleList(
            [BertLayer(config=config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPastAndCrossAttentions:
        for layer_module in self.layer:
            inputs = layer_module(inputs=inputs, attention_mask=attention_mask)

        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=inputs)



BertPooler

Finally, there is the BertPooler class. In BERT it basically takes the CLS token (the first token), runs it through a linear layer and a Tanh activation, and uses the result as the pooled output, which can be viewed as the information BERT has distilled from the whole sequence.

class BertPooler(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.activation = torch.nn.Tanh()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        cls_token_tensor = inputs[:, 0]
        pooler_outputs = self.dense(cls_token_tensor)
        pooler_outputs = self.activation(pooler_outputs)
        return pooler_outputs



BertModel

Here, besides wiring all the components together into BertModel, I also implemented a .from_pretrained() interface similar to the one in the HuggingFace transformers package, so BERT-family models from the HuggingFace Hub can be loaded.

class BertModel(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = BertEncoder(config=config)
        self.pooler = BertPooler(config=config)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPoolingAndCrossAttentions:
        embedding = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        encoded = self.encoder(embedding, attention_mask=attention_mask)
        pooler_output = self.pooler(encoded.last_hidden_state)
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=encoded.last_hidden_state,
            pooler_output=pooler_output,
        )
    
    @staticmethod
    def from_pretrained(pretrained_model_name_or_path: str) -> "BertModel":
        """Load pretrained weights from HuggingFace into model.
        
        Args:
            pretrained_model_name_or_path: One of
                * "prajjwal1/bert-tiny"
                ...

        Returns:
            model: BertModel model with weights loaded
        """

        def load_state_dict_hf(path_or_repo_id: str) -> OrderedDict:
            resolved_archive_file = cached_file(
                path_or_repo_id=path_or_repo_id,
                filename=WEIGHTS_NAME,
            )
            return torch.load(resolved_archive_file, weights_only=True)

        # Load config
        config = BertConfig.from_pretrained_model_or_path(pretrained_model_name_or_path=pretrained_model_name_or_path)

        # Load weights
        new_state_dict = {}
        state_dict = load_state_dict_hf(pretrained_model_name_or_path)

        for key in state_dict:
            if "cls" in key or "position_ids" in key:
                continue

            new_key = key.replace("bert.", "")

            if "LayerNorm.gamma" in key:
                new_key = new_key.replace("gamma", "weight")
            elif "LayerNorm.beta" in key:
                new_key = new_key.replace("beta", "bias")
            
            new_state_dict[new_key] = state_dict[key]

        # Load model
        model = BertModel(config=config)
        model.load_state_dict(new_state_dict)

        return model



Complete code

from typing import OrderedDict, Optional, List, Tuple

from dataclasses import dataclass
import json
import math

import torch
from transformers import AutoModel, BertTokenizer
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


@dataclass
class BertConfig:
    architectures: Optional[List[str]] = None
    attention_probs_dropout_prob: float = 0.1
    gradient_checkpointing: Optional[bool] = False
    classifier_dropout: Optional[float] = None
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    hidden_size: int = 128
    initializer_range: float = 0.02
    intermediate_size: int = 512
    layer_norm_eps: float = 1e-12
    max_position_embeddings: int = 512
    model_type: str = "bert"
    num_attention_heads: int = 2
    num_hidden_layers: int = 2
    pad_token_id: int = 0
    position_embedding_type: str = "absolute"
    transformers_version: str = "4.36.2"
    type_vocab_size: int = 2
    use_cache: bool = True
    vocab_size: int = 30522

    @staticmethod
    def from_pretrained_model_or_path(pretrained_model_name_or_path: str) -> "BertConfig":
        resolved_archive_file = cached_file(
            path_or_repo_id=pretrained_model_name_or_path,
            filename=CONFIG_NAME,
            _raise_exceptions_for_missing_entries=False,
        )
        with open(resolved_archive_file) as config_file:
            config_content = json.load(config_file)
        return BertConfig(**config_content)


@dataclass
class BaseModelOutputWithPastAndCrossAttentions:
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions:
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


class BertEmbeddings(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = torch.nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = torch.nn.LayerNorm((config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids, dtype=torch.long, device=input_ids.device)

        word_embedding = self.word_embeddings(input_ids)
        token_type_embedding = self.token_type_embeddings(token_type_ids)
        position_embedding = self.position_embeddings(position_ids)

        # Combine all embeddings
        embedding = word_embedding + position_embedding + token_type_embedding

        embedding = self.LayerNorm(embedding)
        embedding = self.dropout(embedding)
        return embedding
    

class BertSelfAttention(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.head_size * self.num_attention_heads

        self.query = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.key = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.value = torch.nn.Linear(config.hidden_size, self.all_head_size)
        
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
    
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        q = self.transpose_for_scores(self.query(inputs))
        k = self.transpose_for_scores(self.key(inputs))
        v = self.transpose_for_scores(self.value(inputs))

        # Attention scores: scaled dot product between Q and K
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.head_size)

        # Mask out padding positions by adding a large negative value to their scores
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :].to(dtype=inputs.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(inputs.dtype).min
            attention_scores = attention_scores + extended_attention_mask

        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        attention_outputs = torch.matmul(attention_probs, v)

        # Merge the heads back: (batch_size, seq_len, all_head_size)
        attention_outputs = attention_outputs.permute(0, 2, 1, 3).contiguous()
        attention_outputs = attention_outputs.view(*attention_outputs.shape[:2], -1)
        return attention_outputs
    

class BertSelfOutput(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.dense(hidden_states)
        outputs = self.dropout(outputs)
        outputs = self.LayerNorm(outputs + inputs)
        return outputs


class BertAttention(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.self = BertSelfAttention(config=config)
        self.output = BertSelfOutput(config=config)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        hidden_states = self.self(inputs, attention_mask=attention_mask)
        outputs = self.output(hidden_states=hidden_states, inputs=inputs)
        return outputs
    

class BertIntermediate(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
        )
        self.intermediate_act_fn = torch.nn.GELU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.dense(inputs)
        outputs = self.intermediate_act_fn(outputs)
        return outputs
    
class BertOutput(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.intermediate_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
    
    def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        outputs = self.LayerNorm(hidden_states + inputs)
        return outputs


class BertPooler(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.activation = torch.nn.Tanh()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        cls_token_tensor = inputs[:, 0]
        pooler_outputs = self.dense(cls_token_tensor)
        pooler_outputs = self.activation(pooler_outputs)
        return pooler_outputs


class BertLayer(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.attention = BertAttention(config=config)
        self.intermediate = BertIntermediate(config=config)
        self.output = BertOutput(config=config)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        attention_outputs = self.attention(inputs=inputs, attention_mask=attention_mask)
        intermediate_outputs = self.intermediate(inputs=attention_outputs)
        outputs = self.output(hidden_states=intermediate_outputs, inputs=attention_outputs)
        return outputs


class BertEncoder(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.layer = torch.nn.ModuleList(
            [BertLayer(config=config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPastAndCrossAttentions:
        for layer_module in self.layer:
            inputs = layer_module(inputs=inputs, attention_mask=attention_mask)

        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=inputs)
    

class BertModel(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = BertEncoder(config=config)
        self.pooler = BertPooler(config=config)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPoolingAndCrossAttentions:
        embedding = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        encoded = self.encoder(embedding, attention_mask=attention_mask)
        pooler_output = self.pooler(encoded.last_hidden_state)
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=encoded.last_hidden_state,
            pooler_output=pooler_output,
        )
    
    @staticmethod
    def from_pretrained(pretrained_model_name_or_path: str) -> "BertModel":
        """Load pretrained weights from HuggingFace into model.
        
        Args:
            pretrained_model_name_or_path: One of
                * "prajjwal1/bert-tiny"
                ...

        Returns:
            model: BertModel model with weights loaded
        """

        def load_state_dict_hf(path_or_repo_id: str) -> OrderedDict:
            resolved_archive_file = cached_file(
                path_or_repo_id=path_or_repo_id,
                filename=WEIGHTS_NAME,
            )
            return torch.load(resolved_archive_file, weights_only=True)

        # Load config
        config = BertConfig.from_pretrained_model_or_path(pretrained_model_name_or_path=pretrained_model_name_or_path)

        # Load weights
        new_state_dict = {}
        state_dict = load_state_dict_hf(pretrained_model_name_or_path)

        for key in state_dict:
            if "cls" in key or "position_ids" in key:
                continue

            new_key = key.replace("bert.", "")

            if "LayerNorm.gamma" in key:
                new_key = new_key.replace("gamma", "weight")
            elif "LayerNorm.beta" in key:
                new_key = new_key.replace("beta", "bias")
            
            new_state_dict[new_key] = state_dict[key]

        # Load model
        model = BertModel(config=config)
        model.load_state_dict(new_state_dict)

        return model


if __name__ == "__main__":
    # Init
    # pretrained_model_name_or_path = "prajjwal1/bert-tiny"
    pretrained_model_name_or_path = "google-bert/bert-base-uncased"

    # Tokenizer
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)

    # Model
    my_model = BertModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path).eval()
    hf_model = AutoModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path).eval()

    # Data
    sentences = [
        ("Today is a nice day", "I want to go to play"),
        ("Hello", "Nice to meet you too")
    ]

    inputs = tokenizer.batch_encode_plus(sentences, add_special_tokens=True, padding=True, return_tensors="pt")

    print(my_model(**inputs).last_hidden_state[0][0][0])
    print(hf_model(**inputs).last_hidden_state[0][0][0])
    print(my_model(**inputs).last_hidden_state[0][0][0] == hf_model(**inputs).last_hidden_state[0][0][0])
    print(torch.allclose(my_model(**inputs).last_hidden_state, hf_model(**inputs).last_hidden_state, atol=1e-5))


Output:

tensor(-0.0092, grad_fn=<SelectBackward0>)
tensor(-0.0092, grad_fn=<SelectBackward0>)
tensor(False)
True


Here I ran one extra comparison to confirm that, when loading the same weights, my implementation and the transformers BERT implementation still agree down to the fifth decimal place: element-wise equality fails, but torch.allclose with atol=1e-5 passes.

