Speeding up QAT by 1.89x with LoRA

TLDR

  • Following the success of the quantized Llama 3.2 1B/3B models, we added a fully PyTorch-native QAT + LoRA flow to torchtune, built on torchao APIs
  • Compared to vanilla QAT, the new flow was 1.89x faster and used 36.1% less memory, with slightly improved accuracy and perplexity as an added benefit
  • Compared to raw finetuning, the new flow was 1.69x faster and recovered 52.7% of the accuracy degradation and 68.7% of the perplexity degradation from quantization

Try it out in torchtune today!

tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora

What is QAT + LoRA?

Quantization-aware training (QAT) inserts fake quantization (quantize, then immediately dequantize) into training, so that the model learns to tolerate the rounding error and is more accurate when it is actually quantized after training. Low-Rank Adaptation (LoRA) freezes the original model parameters and trains small low-rank adapter matrices instead, with the main goal of significantly reducing the resource requirements of training and speeding it up.
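As a back-of-the-envelope illustration of where LoRA's savings come from, the sketch below counts trainable parameters for a single linear layer without bias. The 4096 x 4096 shape is an assumption chosen for illustration; the rank of 8 matches the lora_rank used in the experiments later in this post.

```python
def linear_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Return (full finetuning params, LoRA adapter params) for one bias-free linear layer."""
    full = d_out * d_in                 # W has shape (d_out, d_in) and is trained directly
    lora = rank * d_in + d_out * rank   # A: (rank, d_in) and B: (d_out, rank); W stays frozen
    return full, lora


# Assumed 4096 x 4096 layer, rank 8
full, lora = linear_param_counts(d_in=4096, d_out=4096, rank=8)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
# full: 16,777,216  lora: 65,536  ratio: 0.39%
```

Only the adapters need gradients and optimizer state, which is why the savings grow with the size of the frozen weights.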

Combining the two techniques gives us the best of both worlds: better accuracy after quantization, plus lower resource usage and faster training. Concretely, a linear layer computes the following under each approach (ignoring bias for simplicity):

# Regular training
x -> W @ x

# Quantization-aware training (QAT)
x -> fake_quantize(W) @ fake_quantize(x)

# Low-Rank Adaptation (LoRA)
x -> W_frozen @ x + B @ A @ x

# QAT + LoRA
x -> fake_quantize(W_frozen) @ fake_quantize(x) + B @ A @ x
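The QAT + LoRA transformation for a single linear layer can be sketched in plain PyTorch as follows. The class `QATLoRALinearSketch` and helper `fake_quantize_sym` are hypothetical names, not torchtune's or torchao's API. For brevity the sketch uses symmetric per-row fake quantization (per token for activations, per output channel for weights) rather than the exact scheme used in the experiments, and it includes the alpha / rank scaling that standard LoRA applies to the adapter path, which the pseudocode omits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fake_quantize_sym(t: torch.Tensor, bits: int) -> torch.Tensor:
    """Symmetric per-row fake quantization with a straight-through estimator (STE).

    Values are rounded onto a signed `bits`-bit grid and immediately dequantized,
    so the output stays in floating point but carries the rounding error.
    """
    qmax = 2 ** (bits - 1) - 1
    scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(t / scale), -qmax - 1, qmax) * scale
    # STE: the forward pass uses q, the backward pass treats rounding as identity
    return t + (q - t).detach()


class QATLoRALinearSketch(nn.Module):
    """Hypothetical layer: frozen, fake-quantized base weight plus trainable LoRA adapters."""

    def __init__(self, d_in: int, d_out: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        # Frozen base weight: receives no gradient
        self.weight = nn.Parameter(torch.randn(d_out, d_in) / d_in**0.5, requires_grad=False)
        self.lora_a = nn.Parameter(torch.randn(rank, d_in) / d_in**0.5)
        self.lora_b = nn.Parameter(torch.zeros(d_out, rank))  # zero init: adapter starts as a no-op
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fake_quantize(W_frozen) @ fake_quantize(x); F.linear computes x @ W.T (row-vector form)
        base = F.linear(fake_quantize_sym(x, bits=8), fake_quantize_sym(self.weight, bits=4))
        # B @ A @ x on the unquantized input, scaled by alpha / rank
        lora = F.linear(F.linear(x, self.lora_a), self.lora_b) * self.scaling
        return base + lora


torch.manual_seed(0)
layer = QATLoRALinearSketch(d_in=64, d_out=32)
out = layer(torch.randn(4, 64))
out.sum().backward()
print(layer.weight.grad is None, layer.lora_b.grad is not None)  # True True
```

Because the base weight is frozen, only the small adapter matrices accumulate gradients and optimizer state, while the fake-quantized base path still exposes the model to quantization error during training.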

Today, the vanilla QAT recipe in torchtune does not benefit from the speedups and memory savings of LoRA. Because fake quantization fundamentally adds compute and memory overhead compared to regular training or finetuning, reducing this overhead is an important step toward making QAT more accessible to users.

To this end, we added a new QAT + LoRA recipe that closely mirrors the existing vanilla LoRA recipe but additionally applies fake quantization to the frozen linear weights and the input activations, as shown above. The recipe uses the same torchao QAT APIs as the existing QAT recipe and can be extended to any quantization scheme supported by torchao.

Initial Results

Initial experiments finetuning Llama3-8B showed significant performance and/or accuracy benefits over the QAT, LoRA, and raw finetuning baselines. Highlights include:

  • Compared to QAT: 1.89x higher training throughput and 36.1% less memory, with slightly improved accuracy and perplexity as an added benefit
  • Compared to raw finetuning: 1.69x higher training throughput, while recovering 52.7% of the accuracy degradation and 68.7% of the perplexity degradation from quantization
  • Compared to LoRA: recovered 30.9% of the accuracy degradation and 14.4% of the perplexity degradation from quantization
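The comparison figures above are simple ratios. The sketch below assumes one natural definition of "recovered degradation": (method's quantized metric minus baseline's quantized metric) divided by (baseline's unquantized metric minus baseline's quantized metric), where the baseline is, for example, raw finetuning. The exact formula behind the reported numbers is an assumption here, and the inputs in the usage lines are made-up round numbers, not the measured results.

```python
def recovered_fraction(float_metric: float, baseline_quant: float, method_quant: float) -> float:
    """Share of the baseline's quantization degradation that a method wins back (assumed definition).

    Works unchanged for accuracy (higher is better) and perplexity (lower is better),
    because numerator and denominator flip sign together.
    """
    return (method_quant - baseline_quant) / (float_metric - baseline_quant)


def speedup(new_throughput: float, old_throughput: float) -> float:
    """Throughput ratio, e.g. tokens/sec of the new flow over the old one."""
    return new_throughput / old_throughput


def memory_reduction(new_peak_mem: float, old_peak_mem: float) -> float:
    """Fractional reduction in peak memory."""
    return 1.0 - new_peak_mem / old_peak_mem


# Hypothetical round numbers for illustration only
print(f"{recovered_fraction(0.60, 0.50, 0.55):.1%}")  # accuracy: 50.0%
print(f"{recovered_fraction(10.0, 12.0, 11.0):.1%}")  # perplexity: 50.0%
print(f"{speedup(300.0, 200.0):.2f}x")                # 1.50x
print(f"{memory_reduction(60.0, 100.0):.1%}")         # 40.0%
```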

These experiments were run on 4 A100 GPUs (80GB each), finetuning Llama3-8B on the cleaned Alpaca dataset for 1 epoch with batch_size=1, learning_rate=2e-5, lora_rank=8, lora_alpha=16, weight quantization group_size=32, and activation checkpointing enabled. The quantization scheme for linear layers was int8 asymmetric per-token dynamic activations + int4 symmetric per-group weights.
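The quantization scheme above can be spelled out as two fake-quantization functions. This is a minimal sketch assuming a signed int4 grid of [-8, 7] with scale max|w| / 7 per group of 32 input columns, and a signed int8 grid of [-128, 127] with a min/max-derived scale and zero point computed per token at runtime ("dynamic"). The function names are hypothetical, and torchao's actual implementation may differ in rounding, clamping, and epsilon details.

```python
import torch


def fake_quant_weight_int4_per_group(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Symmetric int4 fake quantization with one scale per group of `group_size` input columns."""
    d_out, d_in = w.shape
    assert d_in % group_size == 0, "d_in must be divisible by group_size"
    g = w.reshape(d_out, d_in // group_size, group_size)
    qmin, qmax = -8, 7
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round(g / scale), qmin, qmax)
    return (q * scale).reshape(d_out, d_in)


def fake_quant_act_int8_per_token_asym(x: torch.Tensor) -> torch.Tensor:
    """Asymmetric int8 fake quantization; scale and zero point are computed per token at runtime."""
    qmin, qmax = -128, 127
    x_min = x.amin(dim=-1, keepdim=True).clamp(max=0.0)  # keep 0.0 inside the range
    x_max = x.amax(dim=-1, keepdim=True).clamp(min=0.0)
    scale = ((x_max - x_min) / (qmax - qmin)).clamp(min=1e-6)
    zero_point = torch.clamp(qmin - torch.round(x_min / scale), qmin, qmax)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale


torch.manual_seed(0)
w = torch.randn(16, 64)  # (d_out, d_in): 64 / 32 = 2 groups per row
x = torch.randn(4, 64)   # 4 tokens
# During QAT, wrap each call in a straight-through estimator so rounding does not block gradients
w_fq = w + (fake_quant_weight_int4_per_group(w) - w).detach()
x_fq = x + (fake_quant_act_int8_per_token_asym(x) - x).detach()
y = x_fq @ w_fq.T        # fake_quantize(W) @ fake_quantize(x), in row-vector form
print(y.shape)           # torch.Size([4, 16])
```

Per-group weight scales keep a single outlier from inflating the step size for a whole row, while per-token activation scales adapt to each token's range at runtime.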

Next Steps

An important next step is to let users take the quantized Llama 3.2 1B/3B checkpoints and further finetune them on their own custom datasets, making these checkpoints more accessible for domain-specific use. To mimic the original training process more closely, we also plan to support optionally unfreezing the base model weights during finetuning, which helped significantly reduce quantization degradation when producing the original quantized Llama 3.2 checkpoints. For further details, please see this issue that tracks these remaining tasks.
