
[Paper Reading] Lifting the Curse of Multilinguality by Pre-training Modular Transformers

Last Updated on 2024-08-19 by Clay

Cross-lingual Modular (X-Mod) is an interesting language model architecture that modularizes the parameters for different languages into per-language module units, allowing the model to use a separate set of parameters when fine-tuning for a new language and thereby (comparatively) avoiding the problem of catastrophic forgetting.

The main reason I looked into this paper is that I recently came across ColBERT-XM: A Modular Multi-Vector Representation Model for Zero-Shot Multilingual Information Retrieval, and a colleague tested this ColBERT model on our internal Chinese dataset, surprisingly finding that it outperformed other models by about 10%.

After quickly reading the paper, I realized that the core issue it aims to address is the so-called Curse of Multilinguality, where a language model's performance in previously learned languages starts to decline as it attempts to cover more languages. This new architecture was proposed to solve this problem.

Let's explore its differences from other language model architectures!


Model Architecture

The basic architecture of the model remains a Transformer, with the usual embedding layer, attention mechanism, and feed-forward layer all included. The biggest difference is the addition of a Modular Layer after these components. LayerNorm is applied before the Modular Layer, and a residual connection carries the information from before the Modular Layer to its output. This likely helps gradient flow and keeps training stable; I assume it is a standard addition that was kept after experimental testing.

We can observe that the additional modular layers range from Language 1 to Language n, meaning that there will be a different module for each language. During inference, only one language module is activated at a time.

From HuggingFace's model card, we can see that the language to be activated needs to be explicitly specified:

from transformers import XmodModel

model = XmodModel.from_pretrained("facebook/xmod-base")
model.set_default_language("en_XX")

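To make the snippet above end-to-end runnable, here is a minimal sketch of a forward pass; the tokenizer class and the example sentence are my own assumptions rather than part of the model card:

from transformers import AutoTokenizer, XmodModel
import torch

tokenizer = AutoTokenizer.from_pretrained("facebook/xmod-base")
model = XmodModel.from_pretrained("facebook/xmod-base")
model.set_default_language("en_XX")  # route every input through the English adapter

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)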

And in the source code, when the hidden_states are passed in, only the adapter corresponding to the activated language processes them:

...
    def lang_adapter(self, lang_ids: torch.Tensor, hidden_states: torch.Tensor):
        # Process subsequent samples with the same lang_id in parallel
        lang_ids, lang_lengths = torch.unique_consecutive(lang_ids, return_counts=True)

        if not self.ln_before_adapter:
            residual = hidden_states

        if self.adapter_layer_norm is not None:
            hidden_states = self.adapter_layer_norm(hidden_states)
        elif self.adapter_reuse_layer_norm:
            hidden_states = self.LayerNorm(hidden_states)

        if self.ln_before_adapter:
            residual = hidden_states

        split_hidden_states = torch.split(hidden_states, lang_lengths.tolist(), 0)
        lang_wise_outputs = []

        for i, (lang_id, split_hidden_state) in enumerate(zip(lang_ids, split_hidden_states)):
            lang = list(self.adapter_modules.keys())[int(lang_id.item())]
            lang_wise_outputs.append(self.adapter_modules[lang](split_hidden_state))
        hidden_states = torch.cat(lang_wise_outputs, 0)

        hidden_states = self.dropout(hidden_states)
        hidden_states += residual
        return hidden_states


Because of this design, the adapters for different languages are kept separate. As shown in the X-MOD part (b) of the image below, when we need to add a new language, we can simply add a new adapter (and expand the vocabulary in the embedding layer, which will be mentioned in the next section).
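As a mental model, here is a toy sketch (my own illustration, not the actual XmodOutput implementation) of how per-language modules can be stored in a dictionary keyed by language code, so that adding a language is just adding a new entry; the sizes and language codes are made up:

import torch
import torch.nn as nn

class LanguageAdapter(nn.Module):
    """Toy bottleneck adapter: project down, apply a non-linearity, project back up."""
    def __init__(self, hidden_size: int, bottleneck_size: int):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck_size)
        self.act = nn.GELU()
        self.up = nn.Linear(bottleneck_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.up(self.act(self.down(hidden_states)))

# One adapter per language, keyed by language code
adapter_modules = nn.ModuleDict({
    "en_XX": LanguageAdapter(768, 384),
    "zh_CN": LanguageAdapter(768, 384),
})

# Extending to a new language only requires registering one more adapter
adapter_modules["th_TH"] = LanguageAdapter(768, 384)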


Model Training Approach

The training of the X-Mod model can be understood in three different stages:

  1. Pre-training procedure
  2. Extending to new languages
  3. Fine-tuning on downstream tasks


1. Pre-training procedure

The pre-training of the X-Mod model follows the classic Masked Language Modeling (MLM) approach, where tokens are randomly masked according to a certain corruption ratio, and during training, the model attempts to reconstruct the correct token based on the context, thereby enhancing its understanding of the text.

Cited from https://arxiv.org/pdf/1810.04805

In this task, data from many languages is used. The data from every language trains the shared embedding layer, attention mechanism, and feed-forward layers, while in the modular layers each language's data is routed to its own language module.
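To make the MLM corruption step concrete, here is a toy sketch; the 15% masking ratio follows BERT's convention and is my assumption here, not a number quoted from the X-Mod paper:

import torch

def mask_tokens(input_ids: torch.Tensor, mask_token_id: int, mlm_probability: float = 0.15):
    """Toy MLM corruption: randomly mask tokens and keep labels only at masked positions."""
    labels = input_ids.clone()
    probability_matrix = torch.full(input_ids.shape, mlm_probability)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # positions ignored by the cross-entropy loss
    corrupted = input_ids.clone()
    corrupted[masked_indices] = mask_token_id
    return corrupted, labels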


2. Extending to new languages

Thanks to the modular design, adding a new language module does not significantly impact the previously learned languages (which is exactly the core problem this research aims to solve). In addition to adding a dedicated language module, the embedding layer and the tokenizer also need to be expanded with the new language's vocabulary.

At this stage, MLM training is still used, but only the embedding layer and the new language module are trained, while all other parameters are frozen.
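A rough sketch of that freezing logic might look like the following; the parameter-name patterns and the new language code th_TH are my own guesses based on the module layout quoted earlier, not code from the paper or the library:

NEW_LANG = "th_TH"  # hypothetical newly added language code

for name, param in model.named_parameters():
    is_embedding = "embeddings" in name
    is_new_adapter = f"adapter_modules.{NEW_LANG}" in name
    param.requires_grad = is_embedding or is_new_adapter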

After this step, we obtain a language model that can understand the new language. Intuitively, though, one might ask: if the model has merely acquired the ability to understand the new language, why not directly train a dedicated model for that language? My understanding is that in the next step, fine-tuning on downstream tasks, the model can further acquire cross-lingual understanding.


3. Fine-tuning on downstream tasks

This step primarily aims to enable the language model to understand cross-lingual text (such as mixed Chinese and English data), and the approach is straightforward.

We have just obtained the new language module and the expanded embedding layer, but at this point those parameters are frozen and excluded from training. Instead, we train the weights shared across the different language modules, namely the attention mechanism and feed-forward layers. At inference time, however, the hidden states are still processed by the new language's module.

(2024/08/19 Update: A colleague reminded me that, according to the paper, the third stage does not necessarily use a multilingual dataset; a single-language dataset is also possible, as the original text refers to "source language data.")
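Putting stage 3 together, here is a hedged sketch based on the usage notes in the transformers documentation for X-MOD; the classification head, the training loop placeholder, and the language codes are illustrative assumptions:

from transformers import XmodForSequenceClassification

model = XmodForSequenceClassification.from_pretrained("facebook/xmod-base", num_labels=2)

# Freeze the embeddings and every language adapter; only the shared
# attention / feed-forward weights (plus the task head) remain trainable.
model.freeze_embeddings_and_language_adapters()

# Fine-tune on the source language (e.g., English) with a standard training loop...
model.set_default_language("en_XX")
# ... training loop on the downstream task goes here ...

# ...then switch to the target language's adapter at inference time
# for zero-shot cross-lingual transfer.
model.set_default_language("zh_CN")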


Conclusion

I'll skip the detailed evaluations, since I've already confirmed with real data at work that ColBERT-XM performs better than other open-source models. So this reading focused on understanding the internal architecture and the problem it aims to solve.

If we think more broadly, could this modular architecture be applied beyond just languages? After all, adapters were originally developed for fine-tuning downstream tasks — could modular layers be responsible for different downstream tasks to avoid the catastrophic forgetting mentioned earlier?

Moreover, many current LLMs can achieve better results through Chain-of-Thought (CoT) reasoning. Could the model use the modular layer to think and generate responses simultaneously? Perhaps related papers already exist on this topic.

