Supercharge your LLM via Retrieval Augmented Fine-tuning

Introduction

Large Language Models (LLMs) have become increasingly valuable for answering questions in specialized domains, such as medical or legal documents. To enhance their performance, it’s common to inject domain-specific knowledge into LLMs through techniques like Retrieval-Augmented Generation (RAG) or fine-tuning. In this blog post, we explore a fine-tuning technique known as Retrieval Augmented Fine-Tuning (RAFT) and evaluate its effectiveness in adapting pre-trained LLMs for RAG in specialized domains.

RAG Today

RAG is a method to enhance LLMs when dealing with knowledge that is not “baked-in” during the pretraining stage. This often involves specific domains or more up-to-date information. A common way to build a RAG system is to retrieve chunked documents from a vector store and directly inject them into the LLM prompt. For example, a common prompt for the LLM would look like this:  

“Context information is below:\n{contexts}\nGiven the context information and not prior knowledge, answer the query.\nQuery: {question}\nAnswer: ”
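As a minimal sketch of how such a prompt might be assembled, the snippet below fills the template above with retrieved chunks. The function name `build_rag_prompt`, the blank-line separator between chunks, and the sample chunks are assumptions made for illustration; the retrieval step itself is not shown.

```python
# Template copied from the prompt above; {contexts} and {question} are filled at query time.
PROMPT_TEMPLATE = (
    "Context information is below:\n{contexts}\n"
    "Given the context information and not prior knowledge, answer the query.\n"
    "Query: {question}\nAnswer: "
)


def build_rag_prompt(chunks: list[str], question: str) -> str:
    """Join retrieved chunks (hypothetical separator: a blank line) and fill the template."""
    contexts = "\n\n".join(chunk.strip() for chunk in chunks)
    return PROMPT_TEMPLATE.format(contexts=contexts, question=question)


if __name__ == "__main__":
    # Stand-in for chunks returned by a vector store.
    retrieved = [
        "RAFT fine-tunes a model on retrieval-augmented QA data.",
        "LoRA trains low-rank adapters instead of all model weights.",
    ]
    print(build_rag_prompt(retrieved, "What data does RAFT fine-tune on?"))
```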

Check out our RAG in 4 lines of code guide. 

While these systems are easy to build, there is often still performance left on the table. The debate then becomes whether RAG or fine-tuning is preferable for a given use case. A recent paper on RAFT studies this problem and proposes a method to adapt a pre-trained LLM by fine-tuning it on retrieval-augmented question answering (QA) data.

What is RAFT?

Retrieval Augmented Fine-Tuning (RAFT), introduced by Zhang et al., is a method designed to enhance the performance of LLMs in specific domains. RAFT improves answer quality by training on Chain of Thought (CoT) responses generated from the provided data. The process involves generating CoT answers with a large pre-trained model and then fine-tuning a smaller, more specialized model on those answers. Learning from high-quality CoT answers refines the smaller model’s reasoning and answer-generation capabilities and boosts its performance. In doing so, RAFT bridges the gap between general-purpose LLMs and the specialized knowledge required for specific domains.

Figure 1: Example LLM prompt to generate CoT answers with explanations given the relevant context along with a set of distractor documents.  

Why use RAFT?

One of RAFT’s main advantages is its ability to fine-tune chat or instruct models without needing to realign them for chat functionalities. This efficiency saves time and resources that would otherwise be spent on re-aligning the model for conversational purposes. By focusing on domain-specific fine-tuning, RAFT ensures that the LLM can generate more accurate and contextually relevant answers.

The original RAFT paper presents experiments using the Llama2-7B model, demonstrating its effectiveness in various specialized domains. In particular, while RAG often improves QA performance over using an LLM alone, fine-tuning and RAFT consistently outperform RAG, and by a larger margin.

This raises the question: How does RAFT perform with newer models like Llama3-8B? By comparing these models, we can gain insights into the scalability and improvements offered by the latest advancements in LLMs.

How does RAFT perform on newer LLMs?

The published code for RAFT is in this GitHub repository. We used all the default settings with some small changes:

  • While the paper uses GPT-4 to generate the questions and answers, we chose the Llama3-70B-instruct model as we host it ourselves. 
  • We generated 1 question per chunk and included 3 distractor documents per data point.
  • Instead of full-parameter supervised fine-tuning, we used LoRA.
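The data point layout described in the list above (one question per chunk, plus 3 distractor documents) can be sketched roughly as follows. This is an illustration under stated assumptions, not the code from the RAFT repository: the function name, the dictionary field names, and the random sampling of distractors are hypothetical, and the sketch always keeps the source chunk in the context.

```python
import random


def build_raft_example(chunks, oracle_idx, question, cot_answer,
                       num_distractors=3, seed=0):
    """Pair the chunk a question was generated from with randomly chosen distractors."""
    rng = random.Random(seed)
    # Every chunk except the source ("oracle") chunk is a distractor candidate.
    candidates = [i for i in range(len(chunks)) if i != oracle_idx]
    distractor_ids = rng.sample(candidates, k=min(num_distractors, len(candidates)))
    docs = [chunks[oracle_idx]] + [chunks[i] for i in distractor_ids]
    rng.shuffle(docs)  # avoid always placing the oracle chunk first
    return {
        "question": question,        # hypothetical field names
        "context": docs,
        "oracle_context": chunks[oracle_idx],
        "cot_answer": cot_answer,
    }


if __name__ == "__main__":
    demo_chunks = [f"chunk {i}" for i in range(10)]
    example = build_raft_example(
        demo_chunks, 2, "A question about chunk 2?", "Step-by-step reasoning, then the answer."
    )
    print(example["context"])
```

Shuffling the documents is a design choice in this sketch so that the model cannot learn to rely on the position of the relevant chunk.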

For data, we used the HotpotQA dataset, specifically the dev set’s chunked contexts, to create the data points (i.e. questions and CoT answers). The original HotpotQA questions and answers are not included in the generated data, so the model cannot memorize them. For the sake of time, we created samples from only 100 chunks. The resulting dataset is available on Hugging Face.

Since our focus is on compute-constrained environments, we are interested in models around the 7-8B range or smaller. As such, we’ve selected Llama3 8B and Llama3.1 8B instruct models and their 4-bit quantized variants for our experiments. 

We also compare the results using Llama2-7B-chat as a baseline. For training, we used the TRL SFT trainer. We used lm-evaluation-harness by EleutherAI and evaluated the fine-tuned models on HotpotQA’s validation set (1k samples) on a single NVIDIA A100-SXM4-40GB. 
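For readers unfamiliar with the metric, the following is a minimal sketch of a token-level F1 score of the kind commonly used for extractive QA benchmarks. It is an assumption that this mirrors the evaluation used here; the exact normalization in lm-evaluation-harness may differ, and the helper names are hypothetical.

```python
import re
import string
from collections import Counter


def normalize(text: str) -> list[str]:
    """Lowercase, strip punctuation and English articles, then split on whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return text.split()


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall between prediction and reference."""
    pred, ref = normalize(prediction), normalize(reference)
    if not pred or not ref:
        # Both empty counts as a match; one empty counts as a miss.
        return float(pred == ref)
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    print(token_f1("The Eiffel Tower", "eiffel tower"))
    print(token_f1("Paris, France", "Paris"))
```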

Results

Figure 2 below shows the F1 scores of the fine-tuned and pretrained models. We observe a significant boost in performance from fine-tuning on RAFT-style data for most tested models. Most notably, the performance increase was over 60% for the Llama3 variants and over 100% for Llama2 7B. By comparison, fine-tuning Llama3.1 8B yields a 16% increase.

By using 4-bit quantized variants of the Llama3 models, we were able to retain 91-94% of the performance while only using 25% of the GPU memory dedicated to the model weights.
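The 25% figure follows from simple arithmetic on the weight storage, sketched below. The sketch assumes the unquantized baseline stores weights in 16 bits and uses a nominal 8 billion parameters; it ignores quantization overhead such as scaling factors and any layers left unquantized, so real numbers will differ somewhat.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for model weights only, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


if __name__ == "__main__":
    nominal_params = 8e9  # assumption: nominal size of an "8B" model
    fp16 = weight_memory_gb(nominal_params, 16)
    int4 = weight_memory_gb(nominal_params, 4)
    print(f"16-bit: {fp16:.1f} GB, 4-bit: {int4:.1f} GB, ratio: {int4 / fp16:.0%}")
```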

For LoRA configurations, we found that using “all-linear” as the target modules is more effective than targeting a subset of modules. A higher LoRA rank (64) also yields higher scores than a lower rank (16). Here we report the best scores obtained from tuning the hyperparameters.
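To give a sense of what the rank setting costs, the sketch below counts the trainable parameters LoRA adds to a single linear layer. LoRA learns two low-rank matrices per adapted layer, one of shape rank × d_in and one of shape d_out × rank, so the count is rank × (d_in + d_out). The 4096 × 4096 layer shape is a hypothetical example, not a measured value from our runs. In the peft library, the settings above would correspond roughly to `LoraConfig(r=64, target_modules="all-linear")`, with the remaining arguments left as an assumption.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the two low-rank matrices LoRA adds to one linear layer."""
    return rank * (d_in + d_out)


if __name__ == "__main__":
    d_in = d_out = 4096  # hypothetical square projection layer
    for rank in (16, 64):
        added = lora_trainable_params(d_in, d_out, rank)
        frozen = d_in * d_out
        print(f"rank={rank}: {added:,} trainable vs {frozen:,} frozen ({added / frozen:.2%})")
```

Quadrupling the rank quadruples the adapter size per layer, and targeting all linear layers rather than a subset multiplies that again by the number of additional layers adapted.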

Figure 2: F1 scores of fine-tuned (blue) and pretrained (orange) models evaluated on 1000 samples of HotpotQA dev set

Discussions and Limitations

Initial runs showed that the CoT answers were cut off when max_new_tokens=512. With max_new_tokens=800, the models were able to generate complete CoT answers. This yields almost 2x the performance of the lower setting, but consumes more time and GPU memory.
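One simple way to check whether a generation limit is too low is to count how many outputs hit the limit without producing an end-of-sequence token. The helper below is a hypothetical sketch of that check on lists of generated token IDs; the function names and the example token IDs are assumptions, and it is not necessarily how the cutoff was diagnosed in our runs.

```python
def is_truncated(generated_ids: list[int], max_new_tokens: int, eos_token_id: int) -> bool:
    """True if generation used the whole token budget and never emitted EOS."""
    hit_limit = len(generated_ids) >= max_new_tokens
    ended_cleanly = eos_token_id in generated_ids
    return hit_limit and not ended_cleanly


def truncation_rate(batch: list[list[int]], max_new_tokens: int, eos_token_id: int) -> float:
    """Fraction of generations in the batch that appear to be cut off."""
    if not batch:
        return 0.0
    cut = sum(is_truncated(ids, max_new_tokens, eos_token_id) for ids in batch)
    return cut / len(batch)


if __name__ == "__main__":
    eos = 0  # hypothetical EOS token id
    outputs = [[5, 6, 7, 8], [5, 6, eos], [9, 9, 9, 9]]
    print(truncation_rate(outputs, max_new_tokens=4, eos_token_id=eos))
```

A high truncation rate at a given limit suggests raising max_new_tokens, at the cost of longer generation time and more GPU memory.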

Time and cost are also important factors to consider. Generating the dataset (100 rows) takes ~30min. At the current inference pricing ($0.0012/request), the dataset costs $0.24 (2 calls/row). Once we have the dataset, fine-tuning the model takes ~10min on average. At the current deep training pricing ($4/hr), the training costs $0.67. The fine-tuned model costs less than $1 end-to-end! Of course, other datasets might have different training needs, and tuning the hyperparameters adds to the cost.
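The cost figures above can be reproduced with a few lines of arithmetic. The prices and durations are the ones quoted in the text; the function and parameter names are just illustrative, and the estimate covers a single training run with no hyperparameter search.

```python
def raft_cost(rows: int, calls_per_row: int, price_per_request: float,
              train_minutes: float, train_price_per_hour: float) -> tuple[float, float, float]:
    """Return (dataset_cost, training_cost, total_cost) in dollars."""
    dataset_cost = rows * calls_per_row * price_per_request
    training_cost = train_minutes / 60 * train_price_per_hour
    return dataset_cost, training_cost, dataset_cost + training_cost


if __name__ == "__main__":
    dataset, training, total = raft_cost(
        rows=100, calls_per_row=2, price_per_request=0.0012,
        train_minutes=10, train_price_per_hour=4.0,
    )
    print(f"dataset: ${dataset:.2f}, training: ${training:.2f}, total: ${total:.2f}")
```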

We used Llama3-70B-instruct as the question-answer generator. There are higher-ranking models on the LMSYS Chatbot Arena that may yield better-quality questions and answers.

What’s Next?

RAFT seems to be an effective method to adapt smaller LLMs to domain-specific data. Starting from context chunks, questions and CoT answers can be easily generated via RAFT to form a dataset for fine-tuning instruct models. This not only removes the need to align a fine-tuned base model, but also drastically reduces the amount of data needed for fine-tuning in general. If you want RAFT to be available on the Clarifai platform, send us a message in our Community Discord channel.