Last Updated on 2024-09-07 by Clay
Introduction
My advisor used to tell me, “Don't just use other people's libraries; you have to write your own to truly understand.” Back then, I didn’t have much time to implement various technologies I was interested in since I was fully occupied with my dissertation. However, I often recall his earnest advice even now, and it prompted me to finally attempt the implementation of BERT, a classic encoder-only transformer model.
In part, this is because I recently discussed with a colleague the significance of the [PAD] special token in the forward pass of the Transformer architecture and explored the details of feature extraction through the Multi-head Attention Mechanism. These discussions reminded me that there are still many nuances and variations in the Transformer architecture that I am not fully familiar with, so revisiting my understanding of this classic model felt both meaningful and necessary.
Thus, over the course of about three weeks, I resisted the temptation to look at how HuggingFace's transformers library implemented it and instead focused on the final output results. I referenced various sources and diagrams online, and after intermittent effort, I completed a PyTorch implementation of the BERT model.
Of course, I must admit that while my BERT implementation matches the HuggingFace transformers BERT model to five decimal places of output precision, discrepancies begin to appear from the sixth decimal place onward. Furthermore, some parts of my implementation are nearly identical to transformers' modeling_bert.py, likely because the examples I consulted were sourced directly from it.
Below, I will briefly introduce my implementation. If there are any errors, please feel free to point them out. Thank you! Finally, my source code is open-sourced on GitHub, so feel free to check it out.
Introduction to the BERT Model
The BERT model was first introduced by Google in their 2018 paper "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." After its release, BERT quickly set new performance records across nearly all mainstream NLP benchmarks, a testament to its power and effectiveness.
As a result, a wave of BERT applications emerged, with various domains creating their own feature extraction and classification models using BERT. This led to an explosion of BERT variants, much like the current trend with GPT (as of February 2024).
BERT follows the classic two-stage approach of pretraining and fine-tuning. The first stage uses two unsupervised tasks to learn features from text: Masked Language Modeling (predicting masked tokens) and Next Sentence Prediction (NSP). After pretraining, BERT has a solid ability to extract features, and developers can fine-tune it for downstream tasks as needed.
Thanks to the extensive pretraining on large datasets, BERT performs well on downstream tasks without requiring massive amounts of data to adapt and converge effectively.
BERT's architecture is also straightforward by today's standards: a combination of self-attention mechanisms and feed-forward neural networks.
Now, let’s implement the actual BERT architecture! The goal is to read in the pretrained model weights from HuggingFace and ensure that the outputs match those from the transformers' BertModel with a reasonable level of precision.
BERT Implementation
While implementing the various modules, you can refer back to the architecture diagram I provided above.
Importing All Necessary Libraries
In this implementation, I've defined two classes: BertConfig and BertModel. I have not implemented BertTokenizer, so in my test code I'll use the transformers library's BertTokenizer for tokenization.
from typing import OrderedDict, Optional, List, Tuple
from dataclasses import dataclass
import json
import math
import torch
from transformers import AutoModel, BertTokenizer
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file
BertConfig
The properties of the following configuration class are based on both bert-tiny and bert-base, and the class is designed to be compatible with the from_pretrained() method from transformers. However, if some BERT models on HuggingFace have more parameters, initialization might fail due to additional attributes not accounted for here.
But since this is a basic implementation meant as a practice exercise rather than production code, I haven't added error handling.
@dataclass
class BertConfig:
architectures: Optional[List[str]] = None
attention_probs_dropout_prob: float = 0.1
gradient_checkpointing: Optional[bool] = False
classifier_dropout: Optional[float] = None
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
hidden_size: int = 128
initializer_range: float = 0.02
intermediate_size: int = 512
layer_norm_eps: float = 1e-12
max_position_embeddings: int = 512
model_type: str = "bert"
num_attention_heads: int = 2
num_hidden_layers: int = 2
pad_token_id: int = 0
position_embedding_type: str = "absolute"
transformers_version: str = "4.36.2"
type_vocab_size: int = 2
use_cache: bool = True
vocab_size: int = 30522
@staticmethod
    def from_pretrained_model_or_path(pretrained_model_name_or_path: str) -> "BertConfig":
resolved_archive_file = cached_file(
path_or_repo_id=pretrained_model_name_or_path,
filename=CONFIG_NAME,
_raise_exceptions_for_missing_entries=False,
)
config_content = json.load(open(resolved_archive_file))
return BertConfig(**config_content)
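As a quick sanity check, the configuration can be built from the defaults above or loaded from a repository on the HuggingFace Hub. The sketch below assumes network access and that the repository's config.json only contains attributes listed in the dataclass (prajjwal1/bert-tiny is one of the models referenced in the test script at the end):

# Default configuration (bert-tiny-sized values defined above)
config = BertConfig()
print(config.hidden_size, config.num_hidden_layers)   # 128 2

# Load the configuration of a model hosted on the HuggingFace Hub
tiny_config = BertConfig.from_pretrained_model_or_path("prajjwal1/bert-tiny")
print(tiny_config.num_attention_heads)                # 2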
Output Sample Classes
These classes don't contribute directly to the BERT model's implementation but are added to align the output format with the transformers library's BertModel. BaseModelOutputWithPastAndCrossAttentions represents an intermediate output format, while BaseModelOutputWithPoolingAndCrossAttentions represents the final model output format.
@dataclass
class BaseModelOutputWithPastAndCrossAttentions:
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions:
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
BertEmbedding
Finally, we’ve arrived at the first layer of the BERT architecture. After we input our data into the BERT model, the first layer it encounters is the Embedding Layer.
The purpose of the embedding layer is to map token IDs to corresponding vectors (also known as embeddings). These embeddings carry learned meaning and, being vectors, are numerical data the model can compute with.
For example, if we want to obtain the vectors for the phrase "Today the weather is nice," we might tokenize it and convert each token to its corresponding ID (the IDs below are purely illustrative; a real tokenizer would typically split words into one or more subword tokens):
["Today", "the", "weather", "is", "nice"] => [1, 2, 3, 4, 5]
Next, through lookup tables, we find the vector corresponding to each ID. This is the work done by the embedding layer.
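Here is a minimal sketch of that lookup using torch.nn.Embedding; the vocabulary size and dimension are arbitrary toy values, not BERT's real ones:

# Toy lookup table: 10 possible IDs, each mapped to a 4-dimensional vector
lookup = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
ids = torch.tensor([[1, 2, 3, 4, 5]])   # ID sequence for one "sentence"
vectors = lookup(ids)                    # each ID is replaced by its vector
print(vectors.shape)                     # torch.Size([1, 5, 4])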
In BERT, apart from the input word sequence, there's also position information (IDs from 0 to seq_len - 1) and token types (token type embeddings). The position information is provided by the position_embeddings layer, which in BERT is a trainable embedding layer rather than a fixed sinusoidal encoding.
Token type embeddings come from BERT's pretraining task of Next Sentence Prediction (NSP).
In NSP, BERT must determine whether the second sentence actually follows the first; the tokens of the two sentences are assigned type IDs 0 and 1 respectively, so we need different token type embeddings to distinguish them.
For cross-encoder tasks, we would use this embedding layer again to differentiate sentences. For single-input tasks, it’s fine to set all token type IDs to 0.
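For example, the transformers tokenizer already produces these IDs for a sentence pair. A quick check, assuming the prajjwal1/bert-tiny vocabulary (which must be downloaded first):

tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-tiny")
encoded = tokenizer("Today is a nice day", "I want to go to play")
print(encoded["token_type_ids"])
# e.g. [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
# 0 covers [CLS] + the first sentence + its [SEP]; 1 covers the second sentence + its [SEP]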
Finally, the three different embeddings are summed, passed through a Layer Normalization (Layer Norm), and then through dropout.
class BertEmbeddings(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.word_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = torch.nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = torch.nn.LayerNorm((config.hidden_size,), eps=config.layer_norm_eps)
self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
def forward(
self,
input_ids: torch.Tensor,
token_type_ids: torch.Tensor = None,
) -> torch.Tensor:
seq_length = input_ids.shape[1]
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids, dtype=torch.long, device=input_ids.device)
word_embedding = self.word_embeddings(input_ids)
token_type_embedding = self.token_type_embeddings(token_type_ids)
position_embedding = self.position_embeddings(position_ids)
# Combine all embeddings
embedding = word_embedding + position_embedding + token_type_embedding
embedding = self.LayerNorm(embedding)
embedding = self.dropout(embedding)
return embedding
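A quick shape check, using the default (bert-tiny-sized) BertConfig defined above:

config = BertConfig()
embeddings = BertEmbeddings(config=config)
input_ids = torch.randint(0, config.vocab_size, (2, 8))   # (batch_size, seq_len)
print(embeddings(input_ids=input_ids).shape)               # torch.Size([2, 8, 128])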
BERT Attention
The self-attention mechanism is the highlight of the Transformer architecture. After computing the embeddings, we pass them through three linear layers to compute the Query, Key, and Value for each token.
The shapes of Query, Key, and Value are initially (batch_size, seq_len, hidden_size). To implement the Multi-head Attention Mechanism (MHA), which allows each head to learn different features, we further split hidden_size into head_num parts, reshaping it to (batch_size, head_num, seq_len, hidden_size / head_num).
Afterward, it's fairly straightforward: we compute the dot product between the Query and the transposed Key, divide by the square root of the per-head dimension (often written as sqrt(d_k) in formulas), and then apply the attention mask, which is usually used to keep padding tokens from participating in the computation. The padding positions are masked with a large negative value, so after the Softmax their contributions become 0.
Next, we compute the dot product between the result and the Value and reshape the output back to (batch_size, seq_len, hidden_size). Finally, we pass it through a linear layer and LayerNorm, completing the self-attention mechanism.
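These steps are handled by the BertSelfAttention module, reproduced here from the complete listing at the end; BertSelfOutput and BertAttention below then add the output projection, residual connection, and LayerNorm:

class BertSelfAttention(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.head_size * self.num_attention_heads
        self.query = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.key = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.value = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
    def forward(self, inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        q = self.transpose_for_scores(self.query(inputs))
        k = self.transpose_for_scores(self.key(inputs))
        v = self.transpose_for_scores(self.value(inputs))
        # Attention scores, scaled by sqrt(head_size)
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.head_size)
        # Apply the attention mask (if provided) so padding positions get a large negative score
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :].to(dtype=inputs.dtype) # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(inputs.dtype).min
            attention_scores = attention_scores + extended_attention_mask
        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        attention_outputs = torch.matmul(attention_probs, v)
        # Merge the heads back: (batch_size, seq_len, all_head_size)
        attention_outputs = attention_outputs.permute(0, 2, 1, 3).contiguous()
        attention_outputs = attention_outputs.view(*attention_outputs.shape[:2], -1)
        return attention_outputs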
class BertSelfOutput(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.dense = torch.nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size,
)
self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
self.dropout = torch.nn.Dropout(p=config.attention_probs_dropout_prob)
def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
outputs = self.dense(hidden_states)
outputs = self.dropout(outputs)
outputs = self.LayerNorm(outputs + inputs)
return outputs
class BertAttention(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.self = BertSelfAttention(config=config)
self.output = BertSelfOutput(config=config)
def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
hidden_states = self.self(inputs, attention_mask=attention_mask)
outputs = self.output(hidden_states=hidden_states, inputs=inputs)
return outputs
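A quick check that the masking works as expected, again with the default config (here the last two positions are treated as padding):

config = BertConfig()
attention = BertAttention(config=config)
x = torch.rand(2, 8, config.hidden_size)
attention_mask = torch.ones(2, 8)
attention_mask[:, -2:] = 0   # pretend the last two tokens are [PAD]
print(attention(inputs=x, attention_mask=attention_mask).shape)   # torch.Size([2, 8, 128])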
BERT Feed-Forward Neural Network
After computing the result of the self-attention mechanism, we move on to the feed-forward neural network (FFN). Keep in mind that both the self-attention mechanism and the FFN are stacked according to the settings in BertConfig. For example, in bert-tiny, these modules are stacked twice.
class BertIntermediate(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.dense = torch.nn.Linear(
in_features=config.hidden_size,
out_features=config.intermediate_size,
)
self.intermediate_act_fn = torch.nn.GELU()
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
outputs = self.dense(inputs)
outputs = self.intermediate_act_fn(outputs)
return outputs
class BertOutput(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.dense = torch.nn.Linear(
in_features=config.intermediate_size,
out_features=config.hidden_size,
)
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
outputs = self.LayerNorm(hidden_states + inputs)
return outputs
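With the bert-tiny defaults, the FFN expands each token's 128-dimensional vector to 512 dimensions and projects it back, with a residual connection and LayerNorm in BertOutput. A quick sketch:

config = BertConfig()
intermediate = BertIntermediate(config=config)
output = BertOutput(config=config)
x = torch.rand(2, 8, config.hidden_size)                  # (batch_size, seq_len, 128)
expanded = intermediate(inputs=x)                         # -> (2, 8, 512)
print(output(hidden_states=expanded, inputs=x).shape)     # torch.Size([2, 8, 128])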
BertLayer
The self-attention mechanism and FFN modules together form a BertLayer, which is then stacked multiple times inside the BertEncoder class.
class BertLayer(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.attention = BertAttention(config=config)
self.intermediate = BertIntermediate(config=config)
self.output = BertOutput(config=config)
    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
attention_outputs = self.attention(inputs=inputs, attention_mask=attention_mask)
intermediate_outputs = self.intermediate(inputs=attention_outputs)
outputs = self.output(hidden_states=intermediate_outputs, inputs=attention_outputs)
return outputs
BertEncoder
In the BertEncoder section, the BertLayer composed of the self-attention mechanism and the FFN is stacked num_hidden_layers times.
class BertEncoder(torch.nn.Module):
    def __init__(self, config: BertConfig) -> None:
        super().__init__()
        self.layer = torch.nn.ModuleList(
            [BertLayer(config=config) for _ in range(config.num_hidden_layers)]
        )
    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPastAndCrossAttentions:
        for layer_module in self.layer:
            inputs = layer_module(inputs=inputs, attention_mask=attention_mask)
        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=inputs)
BertPooler
Finally, we introduce the BertPooler class. In BERT, it takes the hidden state of the [CLS] token (the first token), passes it through a linear layer and a Tanh activation, and returns the result, which can be seen as the distilled information BERT has extracted from the entire sequence.
class BertPooler(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.dense = torch.nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size,
)
self.activation = torch.nn.Tanh()
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
cls_token_tensor = inputs[:, 0]
pooler_outputs = self.dense(cls_token_tensor)
pooler_outputs = self.activation(pooler_outputs)
return pooler_outputs
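The pooler therefore collapses the sequence dimension; a quick sketch with the default config:

config = BertConfig()
pooler = BertPooler(config=config)
hidden_states = torch.rand(2, 8, config.hidden_size)   # (batch_size, seq_len, hidden_size)
print(pooler(hidden_states).shape)                      # torch.Size([2, 128])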
BertModel
Here, I've not only integrated all the components of BertModel but also implemented a .from_pretrained() interface similar to HuggingFace transformers, so it supports the BERT models available on HuggingFace.
class BertModel(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.embeddings = BertEmbeddings(config=config)
self.encoder = BertEncoder(config=config)
self.pooler = BertPooler(config=config)
def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPoolingAndCrossAttentions:
embedding = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
encoded = self.encoder(embedding, attention_mask=attention_mask)
pooler_output = self.pooler(encoded.last_hidden_state)
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=encoded.last_hidden_state,
pooler_output=pooler_output,
)
@staticmethod
def from_pretrained(pretrained_model_name_or_path: str) -> "BertModel":
"""Load pretrained weights from HuggingFace into model.
Args:
pretrained_model_name_or_path: One of
* "prajjwal1/bert-tiny"
...
Returns:
model: BertModel model with weights loaded
"""
def load_state_dict_hf(path_or_repo_id: str) -> OrderedDict:
resolved_archive_file = cached_file(
path_or_repo_id=path_or_repo_id,
filename=WEIGHTS_NAME,
)
return torch.load(resolved_archive_file, weights_only=True)
# Load config
config = BertConfig.from_pretrained_model_or_path(pretrained_model_name_or_path=pretrained_model_name_or_path)
# Load weights
new_state_dict = {}
state_dict = load_state_dict_hf(pretrained_model_name_or_path)
for key in state_dict:
if "cls" in key or "position_ids" in key:
continue
new_key = key.replace("bert.", "")
if "LayerNorm.gamma" in key:
new_key = new_key.replace("gamma", "weight")
elif "LayerNorm.beta" in key:
new_key = new_key.replace("beta", "bias")
new_state_dict[new_key] = state_dict[key]
# Load model
model = BertModel(config=config)
model.load_state_dict(new_state_dict)
return model
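A minimal usage sketch; the full comparison against the transformers implementation is in the test script at the end, which uses google-bert/bert-base-uncased, while prajjwal1/bert-tiny is the smaller option listed in the docstring (both the weights and the tokenizer are downloaded from the HuggingFace Hub):

tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = BertModel.from_pretrained("prajjwal1/bert-tiny").eval()
inputs = tokenizer("Today is a nice day", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)   # torch.Size([1, 7, 128])
print(outputs.pooler_output.shape)       # torch.Size([1, 128])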
Complete Code
from typing import OrderedDict, Optional, List, Tuple
from dataclasses import dataclass
import json
import math
import torch
from transformers import AutoModel, BertTokenizer
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file
@dataclass
class BertConfig:
architectures: Optional[List[str]] = None
attention_probs_dropout_prob: float = 0.1
gradient_checkpointing: Optional[bool] = False
classifier_dropout: Optional[float] = None
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
hidden_size: int = 128
initializer_range: float = 0.02
intermediate_size: int = 512
layer_norm_eps: float = 1e-12
max_position_embeddings: int = 512
model_type: str = "bert"
num_attention_heads: int = 2
num_hidden_layers: int = 2
pad_token_id: int = 0
position_embedding_type: str = "absolute"
transformers_version: str = "4.36.2"
type_vocab_size: int = 2
use_cache: bool = True
vocab_size: int = 30522
@staticmethod
    def from_pretrained_model_or_path(pretrained_model_name_or_path: str) -> "BertConfig":
resolved_archive_file = cached_file(
path_or_repo_id=pretrained_model_name_or_path,
filename=CONFIG_NAME,
_raise_exceptions_for_missing_entries=False,
)
config_content = json.load(open(resolved_archive_file))
return BertConfig(**config_content)
@dataclass
class BaseModelOutputWithPastAndCrossAttentions:
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions:
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
class BertEmbeddings(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.word_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = torch.nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = torch.nn.LayerNorm((config.hidden_size,), eps=config.layer_norm_eps)
self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
def forward(
self,
input_ids: torch.Tensor,
token_type_ids: torch.Tensor = None,
) -> torch.Tensor:
seq_length = input_ids.shape[1]
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids, dtype=torch.long, device=input_ids.device)
word_embedding = self.word_embeddings(input_ids)
token_type_embedding = self.token_type_embeddings(token_type_ids)
position_embedding = self.position_embeddings(position_ids)
# Combine all embeddings
embedding = word_embedding + position_embedding + token_type_embedding
embedding = self.LayerNorm(embedding)
embedding = self.dropout(embedding)
return embedding
class BertSelfAttention(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.head_size = int(config.hidden_size / self.num_attention_heads)
self.all_head_size = self.head_size * self.num_attention_heads
self.query = torch.nn.Linear(config.hidden_size, self.all_head_size)
self.key = torch.nn.Linear(config.hidden_size, self.all_head_size)
self.value = torch.nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
q = self.transpose_for_scores(self.query(inputs))
k = self.transpose_for_scores(self.key(inputs))
v = self.transpose_for_scores(self.value(inputs))
        # Attention scores, scaled by sqrt(head_size)
attention_scores = torch.matmul(q, k.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.head_size)
        # Apply the attention mask (if provided) so padding positions get a large negative score
if attention_mask is not None:
extended_attention_mask = attention_mask[:, None, None, :].to(dtype=inputs.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(inputs.dtype).min
attention_scores = attention_scores + extended_attention_mask
attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
attention_outputs = torch.matmul(attention_probs, v)
        # Merge the heads back: (batch_size, seq_len, all_head_size)
attention_outputs = attention_outputs.permute(0, 2, 1, 3).contiguous()
attention_outputs = attention_outputs.view(*attention_outputs.shape[:2], -1)
return attention_outputs
class BertSelfOutput(torch.nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.dense = torch.nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size,
)
self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
self.dropout = torch.nn.Dropout(p=config.attention_probs_dropout_prob)
def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
outputs = self.dense(hidden_states)
outputs = self.dropout(outputs)
outputs = self.LayerNorm(outputs + inputs)
return outputs
class BertAttention(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.self = BertSelfAttention(config=config)
self.output = BertSelfOutput(config=config)
def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
hidden_states = self.self(inputs, attention_mask=attention_mask)
outputs = self.output(hidden_states=hidden_states, inputs=inputs)
return outputs
class BertIntermediate(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.dense = torch.nn.Linear(
in_features=config.hidden_size,
out_features=config.intermediate_size,
)
self.intermediate_act_fn = torch.nn.GELU()
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
outputs = self.dense(inputs)
outputs = self.intermediate_act_fn(outputs)
return outputs
class BertOutput(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.dense = torch.nn.Linear(
in_features=config.intermediate_size,
out_features=config.hidden_size,
)
        self.LayerNorm = torch.nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_eps)
self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
outputs = self.LayerNorm(hidden_states + inputs)
return outputs
class BertPooler(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.dense = torch.nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size,
)
self.activation = torch.nn.Tanh()
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
cls_token_tensor = inputs[:, 0]
pooler_outputs = self.dense(cls_token_tensor)
pooler_outputs = self.activation(pooler_outputs)
return pooler_outputs
class BertLayer(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.attention = BertAttention(config=config)
self.intermediate = BertIntermediate(config=config)
self.output = BertOutput(config=config)
    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
attention_outputs = self.attention(inputs=inputs, attention_mask=attention_mask)
intermediate_outputs = self.intermediate(inputs=attention_outputs)
outputs = self.output(hidden_states=intermediate_outputs, inputs=attention_outputs)
return outputs
class BertEncoder(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.layer = torch.nn.ModuleList(
[BertLayer(config=config) for _ in range(config.num_hidden_layers)]
)
def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPastAndCrossAttentions:
for layer_module in self.layer:
inputs = layer_module(inputs=inputs, attention_mask=attention_mask)
return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=inputs)
class BertModel(torch.nn.Module):
def __init__(self, config: BertConfig) -> None:
super().__init__()
self.embeddings = BertEmbeddings(config=config)
self.encoder = BertEncoder(config=config)
self.pooler = BertPooler(config=config)
def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor) -> BaseModelOutputWithPoolingAndCrossAttentions:
embedding = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
encoded = self.encoder(embedding, attention_mask=attention_mask)
pooler_output = self.pooler(encoded.last_hidden_state)
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=encoded.last_hidden_state,
pooler_output=pooler_output,
)
@staticmethod
def from_pretrained(pretrained_model_name_or_path: str) -> "BertModel":
"""Load pretrained weights from HuggingFace into model.
Args:
pretrained_model_name_or_path: One of
* "prajjwal1/bert-tiny"
...
Returns:
model: BertModel model with weights loaded
"""
def load_state_dict_hf(path_or_repo_id: str) -> OrderedDict:
resolved_archive_file = cached_file(
path_or_repo_id=path_or_repo_id,
filename=WEIGHTS_NAME,
)
return torch.load(resolved_archive_file, weights_only=True)
# Load config
config = BertConfig.from_pretrained_model_or_path(pretrained_model_name_or_path=pretrained_model_name_or_path)
# Load weights
new_state_dict = {}
state_dict = load_state_dict_hf(pretrained_model_name_or_path)
for key in state_dict:
if "cls" in key or "position_ids" in key:
continue
new_key = key.replace("bert.", "")
if "LayerNorm.gamma" in key:
new_key = new_key.replace("gamma", "weight")
elif "LayerNorm.beta" in key:
new_key = new_key.replace("beta", "bias")
new_state_dict[new_key] = state_dict[key]
# Load model
model = BertModel(config=config)
model.load_state_dict(new_state_dict)
return model
if __name__ == "__main__":
# Init
# pretrained_model_name_or_path = "prajjwal1/bert-tiny"
pretrained_model_name_or_path = "google-bert/bert-base-uncased"
# Tokenizer
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
# Model
my_model = BertModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path).eval()
hf_model = AutoModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path).eval()
# Data
sentences = [
("Today is a nice day", "I want to go to play"),
("Hello", "Nice to meet you too")
]
inputs = tokenizer.batch_encode_plus(sentences, add_special_tokens=True, padding=True, return_tensors="pt")
print(my_model(**inputs).last_hidden_state[0][0][0])
print(hf_model(**inputs).last_hidden_state[0][0][0])
print(my_model(**inputs).last_hidden_state[0][0][0] == hf_model(**inputs).last_hidden_state[0][0][0])
print(torch.allclose(my_model(**inputs).last_hidden_state, hf_model(**inputs).last_hidden_state, atol=1e-5))
Output:
tensor(-0.0092, grad_fn=<SelectBackward0>)
tensor(-0.0092, grad_fn=<SelectBackward0>)
tensor(False)
True
Here, I also used torch.allclose with atol=1e-5 as an additional check to confirm that, when loading the same weights, my BERT implementation produces outputs consistent with the transformers library's BERT implementation to five decimal places.
References
- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding: https://arxiv.org/abs/1810.04805
- HuggingFace Transformers: https://github.com/huggingface/transformers