
[PyTorch] BERT Architecture Implementation Note

Last Updated on 2024-09-07 by Clay

Introduction

My advisor used to tell me, “Don’t just use other people’s libraries; you have to write your own to truly understand.” Back then, I didn’t have much time to implement various technologies I was interested in since I was fully occupied with my dissertation. However, I often recall his earnest advice even now, and it prompted me to finally attempt the implementation of BERT, a classic encoder-only transformer model.

In part, this is because I recently discussed with a colleague the significance of the PAD special symbol in the forward propagation of the Transformer architecture and explored the details of feature extraction through the Multi-head Attention Mechanism. These discussions reminded me that there are still many nuances and variations in the Transformer architecture that I am not fully familiar with. Revisiting my understanding of this classic model is meaningful and necessary.

Thus, over the course of about three weeks, I resisted the temptation to look at how HuggingFace’s transformers library implemented it and instead focused on the final output results. I referenced various sources and diagrams online, and after intermittent effort, I completed a PyTorch implementation of the BERT model.

Of course, I must admit that my BERT implementation matches the HuggingFace transformers BERT model only to about five decimal places; from the sixth decimal place onward, discrepancies begin to appear. Furthermore, some parts of my implementation are nearly identical to transformers’ modeling_bert.py, likely because the examples I consulted were sourced directly from it.

Below, I will briefly introduce my implementation. If there are any errors, please feel free to point them out. Thank you! Finally, my source code is open-sourced on GitHub, so feel free to check it out.


Introduction to the BERT Model

The BERT model was first introduced by Google in their 2018 paper “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” After its release, BERT quickly set new performance records across nearly all mainstream NLP benchmarks, a testament to its power and effectiveness.

As a result, a wave of BERT applications emerged, with various domains creating their own feature extraction and classification models using BERT. This led to an explosion of BERT variants, much like the current trend with GPT (as of February 2024).

BERT follows the classic two-stage approach of pretraining and fine-tuning. The first stage trains on two unsupervised tasks, predicting masked tokens (Masked Language Modeling, MLM) and Next Sentence Prediction (NSP), so that the model learns to extract features from text. After pretraining, BERT has a solid ability to extract features, and developers can fine-tune it for downstream tasks as needed.

Thanks to the extensive pretraining on large datasets, BERT performs well on downstream tasks without requiring massive amounts of data to adapt and converge effectively.

BERT’s architecture is also straightforward by today’s standards: a combination of self-attention mechanisms and feed-forward neural networks.

Now, let’s implement the actual BERT architecture! The goal is to read in the pretrained model weights from HuggingFace and ensure that the outputs match those from the transformers’ BertModel with a reasonable level of precision.


BERT Implementation

While implementing the various modules, you can refer back to the architecture diagram I provided above.

Importing All Necessary Libraries

In this implementation, I’ve defined two classes: BertConfig and BertModel. I have not implemented BertTokenizer, so in my test code, I’ll use the transformers library’s BertTokenizer for tokenization.

from typing import OrderedDict, Optional, List, Tuple

from dataclasses import dataclass
import json
import math

import torch
from transformers import AutoModel, BertTokenizer
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file



BertConfig

The properties of the following configuration class are based on both bert-tiny and bert-base. It is designed to be compatible with the from_pretrained() method from transformers. However, if some BERT models on HuggingFace have more parameters, initialization might fail due to additional attributes not accounted for here.

Since this is a basic practice implementation and not meant for production use, I haven’t added error handling.

@dataclass
class BertConfig:
    architectures: Optional[List[str]] = None
    attention_probs_dropout_prob: float = 0.1
    gradient_checkpointing: Optional[bool] = False
    classifier_dropout: Optional[float] = None
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    hidden_size: int = 128
    initializer_range: float = 0.02
    intermediate_size: int = 512
    layer_norm_eps: float = 1e-12
    max_position_embeddings: int = 512
    model_type: str = "bert"
    num_attention_heads: int = 2
    num_hidden_layers: int = 2
    pad_token_id: int = 0
    position_embedding_type: str = "absolute"
    transformers_version: str = "4.36.2"
    type_vocab_size: int = 2
    use_cache: bool = True
    vocab_size: int = 30522

    @staticmethod
    def from_pretrained_model_or_path(pretrained_model_name_or_path: str) -> "BertConfig":
        resolved_archive_file = cached_file(
            path_or_repo_id=pretrained_model_name_or_path,
            filename=CONFIG_NAME,
            _raise_exceptions_for_missing_entries=False,
        )
        # Use a context manager so the file handle is closed after reading.
        with open(resolved_archive_file) as f:
            config_content = json.load(f)
        return BertConfig(**config_content)
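
As noted above, a checkpoint whose config.json carries keys that this dataclass does not declare would make BertConfig(**config_content) raise a TypeError. One possible workaround is to drop the unknown keys before calling the constructor. The sketch below is only an illustration: TinyConfig is a hypothetical two-field stand-in for BertConfig, and from_dict_ignoring_unknown is a helper name I made up for this example.

```python
from dataclasses import dataclass, fields


@dataclass
class TinyConfig:
    # Hypothetical stand-in for BertConfig with only two fields.
    hidden_size: int = 128
    num_hidden_layers: int = 2


def from_dict_ignoring_unknown(cls, content: dict):
    """Build a dataclass instance, dropping keys the dataclass does not declare."""
    known = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in content.items() if k in known})


# "some_new_option" is a made-up key that TinyConfig does not declare.
raw = {"hidden_size": 768, "num_hidden_layers": 12, "some_new_option": True}
config = from_dict_ignoring_unknown(TinyConfig, raw)
print(config)
```

The trade-off is that silently ignoring keys can hide a real mismatch between the checkpoint and the architecture, so logging the dropped keys would probably be wise in anything beyond a practice project.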



Output Sample Classes

These classes don’t contribute directly to the BERT model’s implementation but are added to align the output format with the transformers library’s BertModel.

BaseModelOutputWithPastAndCrossAttentions represents an intermediate output format, while BaseModelOutputWithPoolingAndCrossAttentions represents the final model output format.

@dataclass
class BaseModelOutputWithPastAndCrossAttentions:
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions:
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None



BertEmbedding

Finally, we’ve arrived at the first layer of the BERT architecture. After we input our data into the BERT model, the first layer it encounters is the Embedding Layer.

The purpose of the Embedding Layer is to map specific IDs to corresponding vectors (also known as embeddings). These embeddings are meaningful and, being vectors, are necessarily numerical data.

For example, if we want to obtain the vector for the phrase “Today the weather is nice,” we might tokenize it and convert the tokens to their corresponding ID sequence (the IDs below are purely illustrative; a real tokenizer may split a single word into several subword tokens):

["Today", "the", "weather", "is", "nice"] =>
[1, 2, 3, 4, 5]


Next, through lookup tables, we find the vector corresponding to each ID. This is the work done by the embedding layer.

In BERT, apart from the input token sequence, there’s also position information (positions 0 to seq_len - 1) and token type information (token type embeddings).

Position information is provided by the position_embedding layer, which in BERT is a trainable neural layer.

Token type embeddings come from BERT’s pretraining task of Next Sentence Prediction (NSP).

In NSP, BERT needs to determine whether two sentences, with tokens labeled as 0 and 1 respectively, are contextually related. Hence, we need different token type embeddings to distinguish them.

For cross-encoder tasks, we would use this embedding layer again to differentiate sentences. For single-input tasks, it’s fine to set all token type IDs to 0.

Finally, the three different embeddings are summed, passed through a Layer Normalization (Layer Norm), and then through dropout.

class BertEmbeddings(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = torch.nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = torch.nn.LayerNorm((config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids, dtype=torch.long, device=input_ids.device)

        word_embedding = self.word_embeddings(input_ids)
        token_type_embedding = self.token_type_embeddings(token_type_ids)
        position_embedding = self.position_embeddings(position_ids)

        # Combine all embeddings
        embedding = word_embedding + position_embedding + token_type_embedding

        embedding = self.LayerNorm(embedding)
        embedding = self.dropout(embedding)
        return embedding
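
To make the shapes concrete, here is a minimal sketch of the three lookups and their sum. The sizes are toy values chosen for readability, not BERT’s real dimensions, and the ID tensors are made up for illustration.

```python
import torch

# Toy sizes chosen for illustration only; they are not BERT's real sizes.
vocab_size, max_positions, type_vocab_size, hidden_size = 10, 8, 2, 4

word_embeddings = torch.nn.Embedding(vocab_size, hidden_size, padding_idx=0)
position_embeddings = torch.nn.Embedding(max_positions, hidden_size)
token_type_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size)

# A batch of 2 sequences of length 5; ID 0 is padding.
input_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 0, 0, 0]])
token_type_ids = torch.tensor([[0, 0, 0, 1, 1], [0, 0, 0, 0, 0]])
position_ids = torch.arange(input_ids.shape[1])  # 0 .. seq_len - 1

word = word_embeddings(input_ids)                    # (2, 5, 4)
token_type = token_type_embeddings(token_type_ids)   # (2, 5, 4)
position = position_embeddings(position_ids)         # (5, 4), broadcast over the batch

summed = word + position + token_type
print(summed.shape)
```

Note that position_ids is one-dimensional, so the position embeddings are shared by every sequence in the batch through broadcasting, which is also what the BertEmbeddings class above relies on.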



BERT Attention

The self-attention mechanism is the highlight of the Transformer architecture. After computing the embeddings, we pass them through three linear layers to compute the Query, Key, and Value for each token.

The shapes of Query, Key, and Value are initially (batch_size, seq_len, hidden_size). To implement the Multi-head Attention Mechanism (MHA) that allows each head to learn different features, we further split hidden_size into head_num parts, reshaping it to (batch_size, head_num, seq_len, hidden_size / head_num).

Afterward, it’s fairly straightforward: we compute the dot product between the Query and the transposed Key, divide by the square root of the per-head size (hidden_size / head_num, often written as sqrt(d_k) in formulas), and then apply the attention mask (usually used to mask padding tokens so they don’t participate in the computation). Padding positions receive a very large negative value, so after Softmax their attention weights become 0.

Next, we multiply the attention weights by the Value and reshape the output back to (batch_size, seq_len, hidden_size). Finally, we pass it through a linear layer, apply dropout, add the residual connection from the layer input, and apply LayerNorm, completing the self-attention mechanism. The BertSelfAttention class that performs the Query/Key/Value computation is listed in the Complete Code section below; the two classes shown here wrap it with the output projection.

class BertSelfOutput(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
        # The output projection uses the hidden-state dropout rate.
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.dense(hidden_states)
        outputs = self.dropout(outputs)
        outputs = self.LayerNorm(outputs + inputs)
        return outputs


class BertAttention(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.self = BertSelfAttention(config=config)
        self.output = BertSelfOutput(config=config)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        hidden_states = self.self(inputs, attention_mask=attention_mask)
        outputs = self.output(hidden_states=hidden_states, inputs=inputs)
        return outputs
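
The head splitting and padding mask described above can be sketched in isolation as follows. This is a simplified illustration rather than the module itself: the Query/Key/Value tensors are random stand-ins for the outputs of the three linear layers, and the sizes are toy values.

```python
import math
import torch

torch.manual_seed(0)
batch_size, seq_len, hidden_size, num_heads = 1, 4, 8, 2
head_size = hidden_size // num_heads


def split_heads(x: torch.Tensor) -> torch.Tensor:
    # (batch, seq_len, hidden) -> (batch, heads, seq_len, head_size)
    return x.view(batch_size, seq_len, num_heads, head_size).permute(0, 2, 1, 3)


# Random stand-ins for the outputs of the Query/Key/Value linear layers.
q = split_heads(torch.randn(batch_size, seq_len, hidden_size))
k = split_heads(torch.randn(batch_size, seq_len, hidden_size))
v = split_heads(torch.randn(batch_size, seq_len, hidden_size))

# Scaled dot-product scores: (batch, heads, seq_len, seq_len)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)

# 1 marks a real token, 0 marks padding; the last position is padding here.
attention_mask = torch.tensor([[1, 1, 1, 0]])
additive_mask = (1.0 - attention_mask[:, None, None, :].float()) * torch.finfo(scores.dtype).min
probs = torch.softmax(scores + additive_mask, dim=-1)

# Weighted sum of values, then merge the heads back to hidden_size.
context = torch.matmul(probs, v)
context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, hidden_size)

print(probs[0, 0])     # the column for the padded position is 0 in every row
print(context.shape)
```

Each row of probs still sums to 1, but the padded position receives no weight, so it cannot influence the representation of the real tokens.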



BERT Feed-Forward Neural Network

After computing the result of the self-attention mechanism, we move on to the feed-forward neural network (FFN). Keep in mind that both the self-attention mechanism and the FFN are stacked according to the settings in BertConfig. For example, in bert-tiny, these modules are stacked twice.

class BertIntermediate(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
        )
        self.intermediate_act_fn = torch.nn.GELU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.dense(inputs)
        outputs = self.intermediate_act_fn(outputs)
        return outputs
    
class BertOutput(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.intermediate_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
    
    def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        outputs = self.LayerNorm(hidden_states + inputs)
        return outputs
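
For reference, torch.nn.GELU() with its default arguments computes the exact, erf-based form of GELU. A plain-Python sketch of that formula, written here only to show what the activation does to a single value, might look like this:

```python
import math


def gelu(x: float) -> float:
    """Exact (erf-based) GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"gelu({x:+.1f}) = {gelu(x):+.4f}")
```

Large positive inputs pass through almost unchanged, large negative inputs are pushed toward 0, and values near 0 are scaled smoothly in between.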



BertLayer

The self-attention mechanism and FFN modules together form a BertLayer, which is then stacked multiple times in the BertEncoder class.

class BertLayer(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.attention = BertAttention(config=config)
        self.intermediate = BertIntermediate(config=config)
        self.output = BertOutput(config=config)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        attention_outputs = self.attention(inputs=inputs, attention_mask=attention_mask)
        intermediate_outputs = self.intermediate(inputs=attention_outputs)
        outputs = self.output(hidden_states=intermediate_outputs, inputs=attention_outputs)
        return outputs



BertEncoder

In the BertEncoder section, the BertLayer composed of the self-attention mechanism and FFN modules is stacked multiple times.

class BertEncoder(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.layer = torch.nn.ModuleList(
            [BertLayer(config=config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPastAndCrossAttentions:
        for layer_module in self.layer:
            inputs = layer_module(inputs=inputs, attention_mask=attention_mask)

        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=inputs)



BertPooler

Finally, we introduce the BertPooler class. In BERT, this takes the hidden state of the CLS token (the first token), passes it through a linear layer and a Tanh activation, and returns the result, which can be seen as the distilled information BERT has extracted from the entire sequence.

class BertPooler(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.activation = torch.nn.Tanh()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        cls_token_tensor = inputs[:, 0]
        pooler_outputs = self.dense(cls_token_tensor)
        pooler_outputs = self.activation(pooler_outputs)
        return pooler_outputs



BertModel

Here, I’ve not only integrated all the components of BertModel but also implemented the .from_pretrained() interface, similar to HuggingFace transformers, so it supports the BERT models available on HuggingFace.

class BertModel(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = BertEncoder(config=config)
        self.pooler = BertPooler(config=config)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPoolingAndCrossAttentions:
        embedding = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        encoded = self.encoder(embedding, attention_mask=attention_mask)
        pooler_output = self.pooler(encoded.last_hidden_state)
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=encoded.last_hidden_state,
            pooler_output=pooler_output,
        )
    
    @staticmethod
    def from_pretrained(pretrained_model_name_or_path: str) -> "BertModel":
        """Load pretrained weights from HuggingFace into model.
        
        Args:
            pretrained_model_name_or_path: One of
                * "prajjwal1/bert-tiny"
                ...

        Returns:
            model: BertModel model with weights loaded
        """

        def load_state_dict_hf(path_or_repo_id: str) -> OrderedDict:
            resolved_archive_file = cached_file(
                path_or_repo_id=path_or_repo_id,
                filename=WEIGHTS_NAME,
            )
            return torch.load(resolved_archive_file, weights_only=True)

        # Load config
        config = BertConfig.from_pretrained_model_or_path(pretrained_model_name_or_path=pretrained_model_name_or_path)

        # Load weights
        new_state_dict = {}
        state_dict = load_state_dict_hf(pretrained_model_name_or_path)

        for key in state_dict:
            if "cls" in key or "position_ids" in key:
                continue

            new_key = key.replace("bert.", "")

            if "LayerNorm.gamma" in key:
                new_key = new_key.replace("gamma", "weight")
            elif "LayerNorm.beta" in key:
                new_key = new_key.replace("beta", "bias")
            
            new_state_dict[new_key] = state_dict[key]

        # Load model
        model = BertModel(config=config)
        model.load_state_dict(new_state_dict)

        return model
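
The key-renaming loop inside from_pretrained() can be read as a small pure function, which makes it easy to try on individual names. The sketch below mirrors that loop; rename_checkpoint_key is a name I chose for this example, and the sample keys are only illustrations of the naming style found in BERT checkpoints.

```python
from typing import Optional


def rename_checkpoint_key(key: str) -> Optional[str]:
    """Map one checkpoint key to this implementation's name, or None to skip it."""
    if "cls" in key or "position_ids" in key:
        return None  # pretraining heads and the position_ids buffer are not used
    new_key = key.replace("bert.", "")
    if "LayerNorm.gamma" in key:
        new_key = new_key.replace("gamma", "weight")
    elif "LayerNorm.beta" in key:
        new_key = new_key.replace("beta", "bias")
    return new_key


# Example key names in the style of a BERT checkpoint (illustrative only).
examples = [
    "bert.embeddings.LayerNorm.gamma",
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.embeddings.position_ids",
    "cls.predictions.bias",
]
for key in examples:
    print(f"{key} -> {rename_checkpoint_key(key)}")
```

Because load_state_dict() is called with its default strict behavior, any key that is renamed incorrectly or left over would raise an error, which is a useful safety net while experimenting.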



Complete Code

from __future__ import annotations

from typing import OrderedDict, Optional, List, Tuple

from dataclasses import dataclass
import json
import math

import torch
from transformers import AutoModel, BertTokenizer
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


@dataclass
class BertConfig:
    architectures: Optional[List[str]] = None
    attention_probs_dropout_prob: float = 0.1
    gradient_checkpointing: Optional[bool] = False
    classifier_dropout: Optional[float] = None
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    hidden_size: int = 128
    initializer_range: float = 0.02
    intermediate_size: int = 512
    layer_norm_eps: float = 1e-12
    max_position_embeddings: int = 512
    model_type: str = "bert"
    num_attention_heads: int = 2
    num_hidden_layers: int = 2
    pad_token_id: int = 0
    position_embedding_type: str = "absolute"
    transformers_version: str = "4.36.2"
    type_vocab_size: int = 2
    use_cache: bool = True
    vocab_size: int = 30522

    @staticmethod
    def from_pretrained_model_or_path(pretrained_model_name_or_path: str) -> "BertConfig":
        resolved_archive_file = cached_file(
            path_or_repo_id=pretrained_model_name_or_path,
            filename=CONFIG_NAME,
            _raise_exceptions_for_missing_entries=False,
        )
        # Use a context manager so the file handle is closed after reading.
        with open(resolved_archive_file) as f:
            config_content = json.load(f)
        return BertConfig(**config_content)


@dataclass
class BaseModelOutputWithPastAndCrossAttentions:
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions:
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


class BertEmbeddings(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = torch.nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = torch.nn.LayerNorm((config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids, dtype=torch.long, device=input_ids.device)

        word_embedding = self.word_embeddings(input_ids)
        token_type_embedding = self.token_type_embeddings(token_type_ids)
        position_embedding = self.position_embeddings(position_ids)

        # Combine all embeddings
        embedding = word_embedding + position_embedding + token_type_embedding

        embedding = self.LayerNorm(embedding)
        embedding = self.dropout(embedding)
        return embedding
    

class BertSelfAttention(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.head_size * self.num_attention_heads

        self.query = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.key = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.value = torch.nn.Linear(config.hidden_size, self.all_head_size)
        
        # Dropout on the attention probabilities uses its own rate.
        self.dropout = torch.nn.Dropout(p=config.attention_probs_dropout_prob)
    
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        q = self.transpose_for_scores(self.query(inputs))
        k = self.transpose_for_scores(self.key(inputs))
        v = self.transpose_for_scores(self.value(inputs))

        # Attention scores, scaled by the square root of the per-head size
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.head_size)

        # Apply the padding mask only when an attention_mask is provided
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :].to(dtype=inputs.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(inputs.dtype).min
            attention_scores = attention_scores + extended_attention_mask

        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        attention_outputs = torch.matmul(attention_probs, v)

        # Merge the heads back: (batch_size, seq_len, hidden_size)
        attention_outputs = attention_outputs.permute(0, 2, 1, 3).contiguous()
        attention_outputs = attention_outputs.view(*attention_outputs.shape[:2], -1)
        return attention_outputs
    

class BertSelfOutput(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
        # The output projection uses the hidden-state dropout rate.
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.dense(hidden_states)
        outputs = self.dropout(outputs)
        outputs = self.LayerNorm(outputs + inputs)
        return outputs


class BertAttention(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.self = BertSelfAttention(config=config)
        self.output = BertSelfOutput(config=config)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        hidden_states = self.self(inputs, attention_mask=attention_mask)
        outputs = self.output(hidden_states=hidden_states, inputs=inputs)
        return outputs
    

class BertIntermediate(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
        )
        self.intermediate_act_fn = torch.nn.GELU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.dense(inputs)
        outputs = self.intermediate_act_fn(outputs)
        return outputs
    
class BertOutput(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.intermediate_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
    
    def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        outputs = self.LayerNorm(hidden_states + inputs)
        return outputs


class BertPooler(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.activation = torch.nn.Tanh()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        cls_token_tensor = inputs[:, 0]
        pooler_outputs = self.dense(cls_token_tensor)
        pooler_outputs = self.activation(pooler_outputs)
        return pooler_outputs


class BertLayer(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.attention = BertAttention(config=config)
        self.intermediate = BertIntermediate(config=config)
        self.output = BertOutput(config=config)

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        attention_outputs = self.attention(inputs=inputs, attention_mask=attention_mask)
        intermediate_outputs = self.intermediate(inputs=attention_outputs)
        outputs = self.output(hidden_states=intermediate_outputs, inputs=attention_outputs)
        return outputs


class BertEncoder(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.layer = torch.nn.ModuleList(
            [BertLayer(config=config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPastAndCrossAttentions:
        for layer_module in self.layer:
            inputs = layer_module(inputs=inputs, attention_mask=attention_mask)

        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=inputs)
    

class BertModel(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = BertEncoder(config=config)
        self.pooler = BertPooler(config=config)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPoolingAndCrossAttentions:
        embedding = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        encoded = self.encoder(embedding, attention_mask=attention_mask)
        pooler_output = self.pooler(encoded.last_hidden_state)
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=encoded.last_hidden_state,
            pooler_output=pooler_output,
        )
    
    @staticmethod
    def from_pretrained(pretrained_model_name_or_path: str) -> "BertModel":
        """Load pretrained weights from HuggingFace into model.
        
        Args:
            pretrained_model_name_or_path: One of
                * "prajjwal1/bert-tiny"
                ...

        Returns:
            model: BertModel model with weights loaded
        """

        def load_state_dict_hf(path_or_repo_id: str) -> OrderedDict:
            resolved_archive_file = cached_file(
                path_or_repo_id=path_or_repo_id,
                filename=WEIGHTS_NAME,
            )
            return torch.load(resolved_archive_file, weights_only=True)

        # Load config
        config = BertConfig.from_pretrained_model_or_path(pretrained_model_name_or_path=pretrained_model_name_or_path)

        # Load weights
        new_state_dict = {}
        state_dict = load_state_dict_hf(pretrained_model_name_or_path)

        for key in state_dict:
            if "cls" in key or "position_ids" in key:
                continue

            new_key = key.replace("bert.", "")

            if "LayerNorm.gamma" in key:
                new_key = new_key.replace("gamma", "weight")
            elif "LayerNorm.beta" in key:
                new_key = new_key.replace("beta", "bias")
            
            new_state_dict[new_key] = state_dict[key]

        # Load model
        model = BertModel(config=config)
        model.load_state_dict(new_state_dict)

        return model


if __name__ == "__main__":
    # Init
    # pretrained_model_name_or_path = "prajjwal1/bert-tiny"
    pretrained_model_name_or_path = "google-bert/bert-base-uncased"

    # Tokenizer
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)

    # Model
    my_model = BertModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path).eval()
    hf_model = AutoModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path).eval()

    # Data
    sentences = [
        ("Today is a nice day", "I want to go to play"),
        ("Hello", "Nice to meet you too")
    ]

    inputs = tokenizer.batch_encode_plus(sentences, add_special_tokens=True, padding=True, return_tensors="pt")

    # Run each model once and reuse the outputs for every comparison.
    my_output = my_model(**inputs).last_hidden_state
    hf_output = hf_model(**inputs).last_hidden_state

    print(my_output[0][0][0])
    print(hf_output[0][0][0])
    print(my_output[0][0][0] == hf_output[0][0][0])
    print(torch.allclose(my_output, hf_output, atol=1e-5))


Output:

tensor(-0.0092, grad_fn=<SelectBackward0>)
tensor(-0.0092, grad_fn=<SelectBackward0>)
tensor(False)
True


Here, I also performed an additional comparison to confirm that when reading the same weights, my BERT implementation produces outputs consistent with the transformers library’s BERT implementation up to five decimal places.
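
This also explains why the exact comparison prints False while torch.allclose prints True: allclose accepts two values when |a - b| <= atol + rtol * |b|. A plain-Python sketch of that element-wise criterion follows. The two numbers are hypothetical values invented for this example (they are not taken from the models), and the defaults mirror the call above, where atol=1e-5 is passed and rtol keeps its default of 1e-5.

```python
def is_close(a: float, b: float, rtol: float = 1e-5, atol: float = 1e-5) -> bool:
    """Element-wise test used by torch.allclose: |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


# Hypothetical values that agree to five decimal places but differ afterwards.
mine, reference = -0.0092314, -0.0092351

print(mine == reference)          # exact equality is too strict for floats
print(is_close(mine, reference))  # tolerance-based comparison
```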

