Low-Rank Adaptation

An important paradigm for large language models is pre-training on large-scale general-domain data followed by fine-tuning on task-specific data. However, fine-tuning retrains all model parameters, which becomes costly as models grow. Low-Rank Adaptation (LoRA) freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer, which drastically reduces the number of trainable parameters and the memory required for adaptation.

Consider a language modeling problem where the goal is to maximize conditional probabilities given a task-specific prompt. Given a pre-trained autoregressive model $P_\Phi(y \mid x)$ parameterized by $\Phi$, each downstream task is represented by a training dataset of context-target pairs of token sequences $\mathcal{Z} = \{(x_1, y_1), \ldots, (x_n, y_n)\}$.

Traditionally, the model is initialized to pre-trained weights $\Phi_0$ and updated to $\Phi_0 + \Delta\Phi$ with gradient descent on the conditional language modeling objective:

$$\max_{\Phi} \sum_{(x,y) \in \mathcal{Z}} \sum_{t=1}^{|y|} \log\left( P_{\Phi}(y_t \mid x, y_{<t}) \right)$$
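As a concrete illustration, this objective is the usual next-token cross-entropy, summed only over the target tokens $y$ while the context $x$ is given. The sketch below is a minimal, hypothetical version in PyTorch, assuming a causal LM whose call returns an object with a `.logits` tensor (as in Hugging Face Transformers); the function name and setup are illustrative, not from the text.

```python
import torch
import torch.nn.functional as F

def conditional_lm_loss(model, context_ids, target_ids):
    """Negative log-likelihood of target_ids given context_ids.

    Sums -log P(y_t | x, y_<t) over the target positions only,
    i.e. the (negated) summand of the objective above.
    """
    # Concatenate context and target into one sequence: [x, y]
    input_ids = torch.cat([context_ids, target_ids], dim=-1)   # (batch, |x|+|y|)
    logits = model(input_ids).logits                            # (batch, |x|+|y|, vocab)

    # Logits at position t-1 predict token t; keep the positions that predict y.
    ctx_len = context_ids.size(-1)
    pred_logits = logits[:, ctx_len - 1:-1, :]                  # predictions for y_1 .. y_|y|

    # Cross-entropy = -sum_t log P(y_t | x, y_<t)
    return F.cross_entropy(
        pred_logits.reshape(-1, pred_logits.size(-1)),
        target_ids.reshape(-1),
        reduction="sum",
    )
```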

Training on each downstream task results in a different set of parameters $\Delta\Phi$ with the same dimensions as $\Phi_0$. Low-Rank Adaptation instead encodes the task-specific update as $\Delta\Phi = \Delta\Phi(\Theta)$, where $\Theta$ is a much smaller set of parameters with $|\Theta| \ll |\Phi_0|$:

$$\max_{\Theta} \sum_{(x,y) \in \mathcal{Z}} \sum_{t=1}^{|y|} \log\left( P_{\Phi_0 + \Delta\Phi(\Theta)}(y_t \mid x, y_{<t}) \right)$$
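To make the size difference concrete, here is a small back-of-the-envelope calculation with illustrative numbers (not from the text), using the low-rank factors $B$ and $A$ introduced in the next paragraph:

```python
# Illustrative parameter count for a single d x k weight matrix (assumed sizes):
d, k, r = 4096, 4096, 8

full_update = d * k          # 16,777,216 entries in a dense delta-W
lora_update = r * (d + k)    # 65,536 entries in B (d x r) plus A (r x k)

print(lora_update / full_update)  # ~0.0039, i.e. roughly 0.4% of the full update
```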

For a pre-trained weight matrix $W_0 \in \mathbb{R}^{d \times k}$, the weight update is constrained through a low-rank matrix decomposition $W_0 + \Delta W = W_0 + BA$, where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and the rank $r \ll \min(d, k)$. $A$ is initialized with a random Gaussian and $B$ is initialized with zeros, so that $\Delta W = BA = 0$ at the beginning of training. The modified forward pass yields:

$$h = W_0 x + \Delta W x = W_0 x + BAx$$
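The sketch below shows one way this could be wired up as a drop-in linear layer in PyTorch: the pre-trained weight $W_0$ stays frozen, $A$ is Gaussian-initialized, $B$ starts at zero, and the forward pass adds $BAx$ to the frozen output. This is a minimal illustration under those assumptions, not a reference implementation; the class name and constructor arguments are made up, and the `alpha / r` scaling commonly used in LoRA implementations is included as an assumption rather than something stated in this text.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen pre-trained linear layer with a trainable low-rank update (sketch)."""

    def __init__(self, pretrained: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        d, k = pretrained.out_features, pretrained.in_features

        # W0: frozen pre-trained weights (and bias, if any).
        self.base = pretrained
        for p in self.base.parameters():
            p.requires_grad = False

        # Delta W = B A, with A (r x k) Gaussian-initialized and B (d x r) zero-initialized,
        # so Delta W = B A = 0 at the start of training.
        self.A = nn.Parameter(torch.randn(r, k) * 0.01)
        self.B = nn.Parameter(torch.zeros(d, r))
        self.scale = alpha / r  # common scaling in LoRA implementations (assumption)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W0 x + B A x
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)


# Usage: wrap an existing projection; only A and B receive gradients during fine-tuning.
layer = LoRALinear(nn.Linear(4096, 4096), r=8)
h = layer(torch.randn(2, 4096))
```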