
[Paper Reading] Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection

Last Updated on 2024-07-25 by Clay

Introduction

Retrieval-augmented generation (RAG) is a well-known architecture in current applications of Large Language Models (LLMs). It uses "retrieval" to supply knowledge that the model did not acquire during training, enabling the model to answer questions in the context of specific information.

For many businesses and merchants, using large language models to process text tasks is a practical, readily usable technology. In 2023, for example, countless companies set out to build customer service chatbots that understand internal company data, using models such as ChatGPT, Llama, and Mistral.

In many scenarios, retraining a dedicated customer service AI incurs significant resource costs. It is therefore very helpful to have a model that reads internal company documents and bases its responses on them. The system that enables this is known as the RAG (Retrieval-Augmented Generation) architecture: a retrieval system R uses the user's query Q to fetch the top-K relevant documents, and a large language model then generates a response to Q based on these top-K documents.
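To make the Q → R → top-K → LLM flow concrete, here is a minimal Python sketch. It is an illustration under simplifying assumptions, not a production retriever: a bag-of-words cosine similarity stands in for a real embedding model, and the hypothetical `build_prompt` helper only assembles the text that would be sent to the LLM (the LLM call itself is left out).

```python
import math
import re
from collections import Counter


def bag_of_words(text: str) -> Counter:
    """Lowercased word counts; a crude stand-in for an embedding model."""
    return Counter(re.findall(r"\w+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(count * b[word] for word, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(query: str, documents: list[str], k: int) -> list[str]:
    """Retrieval system R: return the top-K documents most similar to query Q."""
    q = bag_of_words(query)
    ranked = sorted(documents, key=lambda d: cosine(q, bag_of_words(d)), reverse=True)
    return ranked[:k]


def build_prompt(query: str, passages: list[str]) -> str:
    """Pack the retrieved passages and the question into a single prompt for the LLM."""
    context = "\n".join(f"[{i + 1}] {p}" for i, p in enumerate(passages))
    return f"Answer using the context below.\n\nContext:\n{context}\n\nQuestion: {query}\nAnswer:"


docs = [
    "Refunds are processed within 14 days.",
    "Our office is closed on Sundays.",
    "Shipping is free for orders over 50 dollars.",
]
top_k = retrieve("How long do refunds take?", docs, k=1)
print(top_k)  # ['Refunds are processed within 14 days.']
print(build_prompt("How long do refunds take?", top_k))
```

Swapping `bag_of_words` and `cosine` for a real embedding model and vector index changes the quality of retrieval, not the shape of the pipeline.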

RAG Architecture

However, the "hallucination problem" of large language models has not been resolved in the past year. Even with the RAG approach, where relevant information is retrieved for the model to answer user questions, there are still potential issues:

  1. The quality of the retrieved information (context) depends heavily on the embedding model, which may retrieve information unrelated to the user's query and thereby introduce noise into the model's input.
  2. Retrieving information for every query, whether or not it is needed, is inefficient, and supplying more context substantially increases the computational cost of autoregressive models (self-attention cost grows quadratically with sequence length).
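As a rough back-of-the-envelope check on point 2: in a full forward pass, each attention head scores every pair of tokens, so that part of the work grows quadratically with context length. The sketch below ignores the MLP layers (which scale linearly) and KV caching, so treat it as an illustration rather than a cost model; the token counts are arbitrary example values.

```python
def attention_score_entries(seq_len: int) -> int:
    """Number of entries in one head's QK^T score matrix for a full forward pass."""
    return seq_len * seq_len


prompt_only = attention_score_entries(1_000)    # question without retrieved context
with_context = attention_score_entries(4_000)   # same question plus ~3,000 tokens of passages
print(with_context / prompt_only)  # 16.0: 4x the tokens, 16x the pairwise scores
```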

Self-RAG partially addresses the model's hallucination problem by introducing "reflection tokens." These special tokens allow the model to request retrieval when it judges that it needs external information, to score the retrieved passages, to continue its response based on them, and finally to score its own response. Because a separate continuation is generated in parallel for each retrieved passage, the scores from the reflection tokens are used to select among these generated candidates.

You can refer to the above diagram for a comparison of the process between Self-RAG and the typical RAG approach.


Training Of Self-RAG

The training process involves two components: the Critic Model (C) and the Generator Model (M).

The Critic Model (C) is specifically trained to generate reflection tokens, and its primary task is to produce evaluations for the retrieved information and the generated responses. On the other hand, the Generator Model (M) is the model used for inference, responsible for generating responses. It continues training based on the data generated by the Critic Model.

1. Training of the Critic Model (C)

  • C's task is to generate reflection tokens, which are used to evaluate the retrieved passages and assign scores to the generated responses.
  • During training, C utilizes a dataset annotated by GPT-4, which primarily consists of passages retrieved by the retrieval system (R) and the corresponding reflection tokens that C should produce.
  • Finally, C is used to label the training data offline, and the data labeled by C is then used to train the generation model (M). This ensures that, in real-world inference, M is capable of generating these reflection tokens on its own.
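The offline labeling step in the last bullet can be sketched as follows. Everything here is hypothetical and for illustration only: `toy_critic` is a word-overlap rule standing in for the trained critic C (which in Self-RAG is a fine-tuned language model, not a rule), and the bracketed `[Group=value]` token spellings are my own notation, not the paper's exact tokens.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class Example:
    instruction: str          # input x
    passage: Optional[str]    # passage retrieved by R, or None if no retrieval was done
    output: str               # reference output segment y


# Hypothetical critic interface: maps one example to its reflection tokens.
Critic = Callable[[Example], Dict[str, str]]


def toy_critic(ex: Example) -> Dict[str, str]:
    """Stand-in for C: judges relevance/support by word overlap between output and passage."""
    if ex.passage is None:
        return {"retrieve": "[Retrieve=no]", "use": "[IsUSE=4]"}  # toy: fixed usefulness
    overlap = set(ex.output.lower().split()) & set(ex.passage.lower().split())
    return {
        "retrieve": "[Retrieve=yes]",
        "rel": "[IsREL=relevant]" if overlap else "[IsREL=irrelevant]",
        "sup": "[IsSUP=fully supported]" if overlap else "[IsSUP=not supported]",
        "use": "[IsUSE=5]",  # toy: fixed usefulness
    }


def augment(ex: Example, critic: Critic) -> str:
    """Build M's training target by inserting the critic's reflection tokens around the output."""
    t = critic(ex)
    if ex.passage is None:
        return f"{t['retrieve']}{ex.output}{t['use']}"
    return f"{t['retrieve']}<p>{ex.passage}</p>{t['rel']}{ex.output}{t['sup']}{t['use']}"


ex = Example(
    "Who wrote Hamlet?",
    "Hamlet is a tragedy by William Shakespeare.",
    "Hamlet was written by William Shakespeare.",
)
print(augment(ex, toy_critic))
```

Because the critic's judgments end up inside the training targets, M sees reflection tokens as ordinary output to predict, which is what lets it emit them itself at inference time.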

According to the diagram, reflection tokens fall into four categories:

  1. Retrieve: yes/no/continue
  2. IsREL: relevant/irrelevant
  3. IsSUP: fully supported/partially supported/not supported
  4. IsUSE: 5/4/3/2/1
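One way to picture these four groups is as a small, closed vocabulary added to the generator's tokenizer. The sketch below encodes the values listed above and parses them back out of generated text; the `[Group=value]` surface form is an assumed notation for illustration, not the paper's exact token strings.

```python
import re
from typing import Dict, Tuple

# The four reflection-token groups and their possible values, as listed above.
REFLECTION_TOKENS = {
    "Retrieve": ["yes", "no", "continue"],
    "IsREL": ["relevant", "irrelevant"],
    "IsSUP": ["fully supported", "partially supported", "not supported"],
    "IsUSE": ["5", "4", "3", "2", "1"],
}

# Surface string for each token, e.g. "[IsSUP=fully supported]" -> ("IsSUP", "fully supported").
TOKEN_TO_GROUP = {
    f"[{group}={value}]": (group, value)
    for group, values in REFLECTION_TOKENS.items()
    for value in values
}

# Matches any known reflection token (and nothing else).
TOKEN_RE = re.compile("|".join(re.escape(tok) for tok in TOKEN_TO_GROUP))


def parse_reflections(text: str) -> Tuple[str, Dict[str, str]]:
    """Split generated text into its plain text and a {group: value} map of reflection tokens."""
    found = {}
    for tok in TOKEN_RE.findall(text):
        group, value = TOKEN_TO_GROUP[tok]
        found[group] = value  # a later token in the same group overwrites an earlier one
    return TOKEN_RE.sub("", text).strip(), found


print(len(TOKEN_TO_GROUP))  # 13 distinct reflection tokens
plain, tags = parse_reflections(
    "[Retrieve=yes]Paris is the capital of France.[IsSUP=fully supported][IsUSE=5]"
)
```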


2. Training of the Generation Model (M)

M is then trained on the data labeled by C, after which it can generate both text and reflection tokens on its own.

You can think of this as first training a specialized model to learn the reflection tasks and then letting the final generation model learn the same process. It can also be understood as a way to avoid the cost of labeling all of the training data with GPT-4: the task-specific knowledge is instead distilled into a smaller language model (C), which then handles all of the labeling.
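A detail that helps explain how M learns from C's labels: a training target interleaves text, reflection tokens, and retrieved passages, but M only needs to learn to produce the first two. The sketch below assumes the passage (wrapped in `<p>...</p>`, an assumed format) is excluded from the loss, a common choice when training on retrieved context; it works on string pieces rather than real token IDs for readability.

```python
import re
from typing import List, Tuple


def loss_mask(target: str) -> List[Tuple[str, bool]]:
    """Split a training target into pieces, flagging which ones count toward M's loss.

    Text and reflection tokens are trained on (True); a retrieved passage wrapped in
    <p>...</p> is conditioning context only (False).
    """
    pieces = []
    # The capturing group makes re.split keep the <p>...</p> spans as their own pieces.
    for part in re.split(r"(<p>.*?</p>)", target, flags=re.DOTALL):
        if part:  # re.split yields empty strings when a span sits at either end
            pieces.append((part, not part.startswith("<p>")))
    return pieces


target = (
    "[Retrieve=yes]<p>Hamlet is a tragedy by William Shakespeare.</p>"
    "[IsREL=relevant]Hamlet was written by William Shakespeare.[IsSUP=fully supported][IsUSE=5]"
)
for piece, trained in loss_mask(target):
    print(trained, piece)
```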


Inference Of Self-RAG

The key difference between Self-RAG and regular RAG lies in the following process:

  1. The model actively emits a retrieval token when it decides that it needs information.
  2. Segments are generated in parallel, one for each retrieved passage.
  3. The generated segments are checked for relevance, support, and usefulness, and the best ones are selected.
  4. The retrieval process is repeated for the next segment...
  5. And so on, until generation is complete.

Of course, the model may also decide not to retrieve at all and answer solely from its own inherent knowledge.
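Putting these steps together, the control flow can be sketched as below. This is a simplified, greedy (beam size 1) illustration: `needs_retrieval`, `retrieve`, `generate`, and `critique` are hypothetical callables standing in for M's Retrieve token, the retriever R, M's segment generation, and M's IsREL/IsSUP/IsUSE judgments, and the toy functions in the usage example are scripted for demonstration.

```python
from typing import Callable, List, Optional


def self_rag_decode(
    x: str,
    needs_retrieval: Callable[[str, str], bool],         # does M emit Retrieve=yes here?
    retrieve: Callable[[str, str], List[str]],           # R: top-K passages for (x, output so far)
    generate: Callable[[str, str, Optional[str]], str],  # M: next segment; "" means finished
    critique: Callable[[str, str, str], float],          # score for (x, passage, segment)
    max_segments: int = 5,
) -> str:
    """Greedy sketch of the Self-RAG loop: retrieve on demand, generate per passage, keep the best."""
    output = ""
    for _ in range(max_segments):
        passages = retrieve(x, output) if needs_retrieval(x, output) else []
        if passages:
            # One candidate segment per passage; keep the one with the highest critique score.
            candidates = [(d, generate(x, output, d)) for d in passages]
            _, segment = max(candidates, key=lambda c: critique(x, c[0], c[1]))
        else:
            segment = generate(x, output, None)  # answer from the model's own knowledge
        if not segment:
            break
        output += segment
    return output


passages = ["Paris is the capital of France.", "Berlin is the capital of Germany."]


def toy_generate(x: str, output: str, d: Optional[str]) -> str:
    if output:
        return ""  # toy: stop after one segment
    return f"Based on the evidence: {d}" if d else "I am not sure."


answer = self_rag_decode(
    "What is the capital of France?",
    needs_retrieval=lambda x, out: out == "",  # toy: retrieve only for the first segment
    retrieve=lambda x, out: passages,
    generate=toy_generate,
    critique=lambda x, d, seg: 1.0 if "France" in d else 0.0,
)
print(answer)  # Based on the evidence: Paris is the capital of France.
```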

To be more specific, it can be divided into several stages as follows:

Retrieval Stage

The retrieval system R searches for K relevant text passages based on the current input x and the output generated before time step t (y_<t). These passages serve as the basis for the model M's next generation step.
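What distinguishes this step from plain RAG retrieval is that the query can include what M has already written. A minimal sketch, with word overlap standing in for the real retriever (an assumption for illustration); the second call shows how the text generated before step t shifts which passage is retrieved.

```python
import heapq
import re
from typing import List, Set


def words(text: str) -> Set[str]:
    """Lowercased word set; a crude stand-in for the retriever's embedding."""
    return set(re.findall(r"\w+", text.lower()))


def retrieve_for_step(x: str, y_before_t: str, corpus: List[str], k: int) -> List[str]:
    """Return the K passages sharing the most words with the input x plus the output so far."""
    query = words(x) | words(y_before_t)
    return heapq.nlargest(k, corpus, key=lambda d: len(query & words(d)))


corpus = [
    "Marie Curie was born in Warsaw.",
    "The Nobel Prize in Physics 1903 was shared by Curie.",
    "Warsaw is the capital of Poland.",
]
x = "Tell me about Marie Curie."
print(retrieve_for_step(x, "", corpus, k=1))  # ['Marie Curie was born in Warsaw.']
print(retrieve_for_step(x, "She later won the Nobel Prize in Physics.", corpus, k=1))
# ['The Nobel Prize in Physics 1903 was shared by Curie.']
```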


Parallel Generation Stage

The generation model M processes each retrieved passage d and generates a continuation candidate conditioned on it. This yields K different candidates, each based on a different retrieved passage.
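The K generations are independent of one another, which is what makes them parallelizable. In practice this would more likely be a single batched forward pass on the GPU than a thread pool; the threads below simply illustrate that independence. `fake_generate` is a hypothetical stand-in for M.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Tuple


def generate_candidates(
    x: str,
    y_before_t: str,
    passages: List[str],
    generate: Callable[[str, str, str], str],
) -> List[Tuple[str, str]]:
    """Run the generator once per retrieved passage d; return (d, continuation) pairs in input order."""
    if not passages:
        return []
    with ThreadPoolExecutor(max_workers=len(passages)) as pool:
        # pool.map preserves input order, so continuation i belongs to passage i.
        continuations = list(pool.map(lambda d: generate(x, y_before_t, d), passages))
    return list(zip(passages, continuations))


def fake_generate(x: str, y_before_t: str, d: str) -> str:
    """Hypothetical stand-in for M: just echoes the passage it was conditioned on."""
    return f"According to the passage, {d}"


candidates = generate_candidates(
    "What is the capital of France?",
    "",
    ["Paris is the capital of France.", "Lyon is a city in France.", "Berlin is in Germany."],
    fake_generate,
)
print(len(candidates))  # 3: one candidate per retrieved passage
```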


Determining the Best Generation Stage

A segment-level form of the classic beam search decoding algorithm is used: at each time step t, the candidate continuations are evaluated and the top B (the beam size) of the K candidates are kept.

These candidates are evaluated and ranked using the critique metrics derived from the reflection tokens in each generated segment: relevance (IsREL), support (IsSUP), and usefulness (IsUSE).
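A minimal sketch of how the three critique groups might be folded into one number per candidate. The per-value desirability table and the equal weights are assumptions for illustration: for each group, take the expected desirability under M's probabilities for that group's reflection tokens, then sum the groups with weights. A full implementation would also account for the segment's own likelihood under M; that term is left out here.

```python
from typing import Dict

# Assumed desirability of each reflection-token value, in [0, 1].
DESIRABILITY = {
    "IsREL": {"relevant": 1.0, "irrelevant": 0.0},
    "IsSUP": {"fully supported": 1.0, "partially supported": 0.5, "not supported": 0.0},
    "IsUSE": {"5": 1.0, "4": 0.75, "3": 0.5, "2": 0.25, "1": 0.0},
}


def critique_score(probs: Dict[str, Dict[str, float]], weights: Dict[str, float]) -> float:
    """Weighted sum, over critique groups, of expected desirability under M's token probabilities.

    probs[group] holds M's probabilities for that group's reflection tokens; they are
    renormalized within the group (assumes the group gets nonzero probability mass).
    """
    total = 0.0
    for group, weight in weights.items():
        dist = probs[group]
        mass = sum(dist.values())
        expected = sum(p * DESIRABILITY[group][value] for value, p in dist.items()) / mass
        total += weight * expected
    return total


weights = {"IsREL": 1.0, "IsSUP": 1.0, "IsUSE": 1.0}
grounded = {
    "IsREL": {"relevant": 0.9, "irrelevant": 0.1},
    "IsSUP": {"fully supported": 0.8, "partially supported": 0.2, "not supported": 0.0},
    "IsUSE": {"5": 0.7, "4": 0.3},
}
off_topic = {
    "IsREL": {"relevant": 0.2, "irrelevant": 0.8},
    "IsSUP": {"fully supported": 0.1, "partially supported": 0.3, "not supported": 0.6},
    "IsUSE": {"5": 0.1, "3": 0.9},
}
print(critique_score(grounded, weights) > critique_score(off_topic, weights))  # True
```

Raising the weight of a group (for example IsSUP) makes the ranking favor well-supported segments more strongly, which is one way to trade fluency against factual grounding.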

Ultimately, the highest-scoring continuation is chosen as the best output for that time step. If scores are tied, predefined strategies such as random selection, length-based selection, merging the tied answers, or embedding-based comparison are used to break the tie.
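The top-B selection and tie-breaking step might look like the sketch below. Of the tie-breaking strategies mentioned above, it implements only two, applied in an order chosen for illustration: among equal scores, prefer the shorter continuation, and if lengths also match, fall back to a seeded random draw. The `select_top_b` name and the preference for shorter text are assumptions, not part of the paper.

```python
import random
from typing import List, Tuple


def select_top_b(candidates: List[Tuple[str, float]], beam_size: int, seed: int = 0) -> List[str]:
    """Keep the B best (text, score) candidates, best first.

    Ties on score go to the shorter text; remaining ties are broken by a seeded random draw.
    """
    rng = random.Random(seed)
    # Sort key: higher score first, then shorter text, then a random number.
    keyed = [(-score, len(text), rng.random(), text) for text, score in candidates]
    keyed.sort()
    return [text for _, _, _, text in keyed[:beam_size]]


candidates = [
    ("Paris.", 2.5),
    ("Paris is the capital of France.", 2.5),
    ("Berlin.", 1.0),
]
print(select_top_b(candidates, beam_size=2))  # ['Paris.', 'Paris is the capital of France.']
print(select_top_b(candidates, beam_size=1))  # ['Paris.']
```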

This process ensures that the generated responses are not only relevant to the retrieved information but also reflect the generation model's own evaluation of the retrieved passages. In this way, the Self-RAG generator M can produce higher-quality responses while maintaining factual accuracy and information richness.


Results

In terms of overall performance, Self-RAG significantly outperforms many other LLMs and standard RAG models across various tasks. It even surpasses ChatGPT on tasks such as PubHealth and PopQA.


Impressions

Self-RAG's use of special tokens to invoke tools reminds me of Toolformer, except that in Self-RAG the tools are replaced by "retrieval" and "evaluation," and the evaluation is even performed by the model itself.

While the paper reports that Self-RAG outperforms many important models, this comes at the cost of inference time because of the multiple retrievals and tree decoding. However, with the continuous development of quantization and inference tools, such as FlashDecoding-V4 (which I've been meaning to look into but haven't had the time), today's concerns about inference time may someday be fully resolved.

