MLX-LoRA-Studio

Quantization-Aware Training (QAT)

In-training quantization · qat_enable: true

Install a straight-through-estimator fake-quantise hook on every nn.Linear so the adapter trains as if it will be deployed quantised. Keeps outlier channels tame; only affects LoRA/DoRA adapters.

Overview

mlx-lm-lora supports two distinct kinds of quantisation that are easy to confuse: load-time quantisation of the base model (QLoRA) and in-training quantisation (QAT). This page covers Quantization-Aware Training (QAT); see QLoRA for the load-time variant.

QAT is a small hook installed on every nn.Linear after the first optimiser step. The hook fake-quantises the weight on its way into the forward pass, so the model trains as if it would be quantised at inference time. The backward pass uses a straight-through estimator (STE): the gradient computed at the fake-quantised weight is passed through unchanged to the full-precision weight, which is the only copy the optimiser sees and updates.

It composes with load-time quantisation: a typical “QLoRA + QAT” run loads the base model in 4-bit and then trains with the QAT hook enabled so the LoRA updates are robust to that 4-bit precision. QAT is only effective for the SFT/DPO/ORPO trainers (the others do not call _install_qat_hooks).

Intuition

Objective (math)

Symmetric fake quantise (applied inside the forward):

# Same arithmetic as load-time, but at runtime on the current weight tensor:
Ŵ   =  s · clip( round( W / s ),  q_min,  q_max )

# The forward uses Ŵ; the backward uses an STE:
∂ℒ / ∂W   =  ∂ℒ / ∂Ŵ                       (identity — gradient flows through)
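The formula above does not say how s, q_min and q_max are chosen, so here is a minimal NumPy sketch under stated assumptions: one symmetric scale per contiguous group of qat_group_size values along the last axis, s = max|W_g| / q_max, q_max = 2**(b-1) - 1 and q_min = -2**(b-1) for b = qat_bits, with a group size of 0 or less meaning a single per-tensor scale. The function name fake_quantize is hypothetical and is not part of mlx-lm-lora.

```python
import numpy as np


def fake_quantize(w: np.ndarray, bits: int = 8, group_size: int = 64) -> np.ndarray:
    """Return W_hat = s * clip(round(W / s), q_min, q_max), shaped like w.

    Assumed (not taken from mlx-lm-lora): symmetric scale s = max|W_g| / q_max
    per contiguous group of `group_size` values along the last axis;
    group_size <= 0 means a single per-tensor scale.
    """
    q_max = 2 ** (bits - 1) - 1          # e.g. 7 for 4-bit, 127 for 8-bit
    q_min = -(2 ** (bits - 1))           # e.g. -8 for 4-bit, -128 for 8-bit
    if group_size <= 0:
        groups = w.reshape(1, -1)        # one group: per-tensor scale
    else:
        if w.shape[-1] % group_size != 0:
            raise ValueError("last dimension must be divisible by group_size")
        groups = w.reshape(-1, group_size)  # one row per group
    s = np.abs(groups).max(axis=1, keepdims=True) / q_max
    s = np.where(s == 0, 1.0, s)         # all-zero group: any nonzero scale works
    q = np.clip(np.round(groups / s), q_min, q_max)  # integer grid values
    return (s * q).reshape(w.shape)      # dequantise back to float


# Usage: 4-bit, group size 64, on a random 4 x 128 weight matrix.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 128))
W_hat = fake_quantize(W, bits=4, group_size=64)
```

Every value lands within half a step of its group's grid, so the per-element error is at most s/2. A group's maximum can never exceed the whole tensor's maximum, so a per-group scale is never coarser than the per-tensor one, and more bits shrink s further. Note that np.round rounds exact halves to the nearest even integer.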

The hook is implemented as:

w            =  self.weight                                   # full-precision weight
self.weight  =  w  +  stop_gradient( quantize(w)  -  w )      # forward sees Ŵ
out          =  original_forward(self, x)
self.weight  =  w                                             # restore for optimiser

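The stop_gradient around ( quantize(w) − w ) is the STE: that term contributes its value to the forward pass but nothing to the gradient. With the + w outside it, the forward value is exactly Ŵ, the derivative of the substituted weight with respect to w is the identity, and the optimiser only ever touches the full-precision w.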
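The swap-and-restore pattern can be sketched in plain NumPy. NumPy has no autodiff, so the identity gradient that stop_gradient yields under autodiff is written out by hand: the loss gradient is computed with respect to Ŵ and applied unchanged to the full-precision W. ToyLinear, install_fake_quant_hook and ste_sgd_step are hypothetical names for this illustration only (they are not mlx-lm-lora's _install_qat_hooks), and the quantiser is a simplified per-tensor version.

```python
import numpy as np


def fake_quantize(w: np.ndarray, bits: int = 4) -> np.ndarray:
    """Simplified per-tensor symmetric fake quantise used only by this sketch."""
    q_max = 2 ** (bits - 1) - 1
    s = np.abs(w).max() / q_max
    if s == 0:
        return w.copy()
    return s * np.clip(np.round(w / s), -q_max - 1, q_max)


class ToyLinear:
    """Hypothetical stand-in for nn.Linear without bias: y = x @ W.T."""

    def __init__(self, weight: np.ndarray):
        self.weight = weight

    def forward(self, x: np.ndarray) -> np.ndarray:
        return x @ self.weight.T

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return self.forward(x)


def install_fake_quant_hook(layer: ToyLinear, bits: int = 4) -> ToyLinear:
    """Hypothetical hook: the forward sees W_hat, then the full-precision W is restored."""
    original_forward = type(layer).forward

    def hooked_forward(x):
        w = layer.weight                          # full-precision weight
        layer.weight = fake_quantize(w, bits)     # forward sees W_hat
        try:
            return original_forward(layer, x)
        finally:
            layer.weight = w                      # restore for the optimiser

    layer.forward = hooked_forward                # instance attribute shadows the method
    return layer


def ste_sgd_step(layer: ToyLinear, x: np.ndarray, target: np.ndarray, lr: float) -> float:
    """One SGD step on L = 0.5 * ||y - target||^2 with a hand-written STE backward."""
    y = layer(x)                                  # forward pass through W_hat
    grad_y = y - target                           # dL/dy
    grad_w_hat = grad_y.T @ x                     # dL/dW_hat for y = x @ W_hat.T
    layer.weight = layer.weight - lr * grad_w_hat # STE: dL/dW := dL/dW_hat
    return 0.5 * float(np.sum(grad_y ** 2))


# Usage: one step on a random 3 x 8 layer.
rng = np.random.default_rng(0)
W0 = rng.normal(size=(3, 8))
layer = install_fake_quant_hook(ToyLinear(W0.copy()), bits=4)
x = rng.normal(size=(16, 8))
loss = ste_sgd_step(layer, x, target=np.zeros((16, 3)), lr=1e-3)
```

After the step, layer.weight holds updated full-precision values even when an update is smaller than one quantisation step. Such small updates accumulate in W until they move Ŵ onto a different grid point, which is why the optimiser works on the full-precision copy rather than on Ŵ.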

What the settings change

Setting | Default | What it actually changes
qat_enable | false | Install the STE fake-quantise hook on every nn.Linear after the first optimiser step. Only effective for SFT/DPO/ORPO.
qat_bits | 8 | Bit-width used by the hook. Match the inference quantisation: deploy at 4-bit ⇒ qat_bits=4; deploy at 8-bit ⇒ qat_bits=8.
qat_group_size | 64 | Group size used by the hook; 0 or negative means per-tensor. Match the deployment group size; 64 and 128 are common.
qat_start_step | 1 | First optimiser step on which to install the hook. Set it higher if your first few steps produce NaN gradients.
qat_interval | 1 | Re-apply the QAT projection every N steps. The default projects on every step; raise it (e.g. to 4) if the projection cost shows up when profiling.
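A hypothetical settings object that mirrors the table's keys and defaults can help when reasoning about how they interact. How qat_interval counts steps is not spelled out here, so applies_at assumes the hook runs on qat_start_step and then on every qat_interval-th step after it; treat that schedule as an assumption, not the app's documented behaviour. QATSettings is an illustrative name, not an mlx-lm-lora class.

```python
from dataclasses import dataclass


@dataclass
class QATSettings:
    """Hypothetical mirror of the qat_* settings (keys and defaults from the table)."""

    qat_enable: bool = False
    qat_bits: int = 8
    qat_group_size: int = 64
    qat_start_step: int = 1
    qat_interval: int = 1

    def per_tensor(self) -> bool:
        # 0 or negative group size means one scale for the whole tensor
        return self.qat_group_size <= 0

    def applies_at(self, step: int) -> bool:
        """Assumed schedule: active from qat_start_step, then every qat_interval steps."""
        if not self.qat_enable or step < self.qat_start_step:
            return False
        interval = max(1, self.qat_interval)      # guard against 0 or negative
        return (step - self.qat_start_step) % interval == 0


# Usage: a "QLoRA + QAT" style run targeting a 4-bit, group-64 deployment,
# delaying the hook to step 10 and re-applying it every 4 steps.
cfg = QATSettings(qat_enable=True, qat_bits=4, qat_group_size=64,
                  qat_start_step=10, qat_interval=4)
active_steps = [step for step in range(1, 25) if cfg.applies_at(step)]
```

Keeping qat_bits and qat_group_size equal to the deployment settings is the point of the exercise: the grid the adapter learns around is then the same grid it will be served on.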

Which to pick

In the app

On the Train tab, the QAT section is a toggle for qat_enable plus four fields: qat_bits, qat_group_size, qat_start_step and qat_interval.

The section is only effective for SFT, DPO, and ORPO (the only trainers that install the hook). For the “QLoRA + QAT” recipe, set Quantization to 4-bit in the Fine-tune section, then enable QAT here with qat_bits set to 4 so the hook matches the 4-bit precision.

Tips & gotchas

References

See also