Low-Rank Adaptation

An important paradigm for large language models is pre-training on large-scale, general-domain data followed by fine-tuning on task-specific data. However, full fine-tuning retrains all model parameters, which becomes costly as models grow. Low-Rank Adaptation (LoRA) freezes the pre-trained model weights and injects trainable rank-decomposition matrices into each layer, which drastically reduces the number of trainable parameters and the memory required for adaptation.

Consider a language modeling problem where the goal is to maximize the conditional probability of a target sequence given a task-specific prompt. Given a pre-trained autoregressive model $P_{\Phi}(y|x)$ parameterized by $\Phi$, each downstream task is represented by a training dataset of context-target pairs of token sequences $\mathcal{Z} = \{(x_1, y_1), \dots, (x_n, y_n)\}$.

Traditionally, the model is initialized to the pre-trained weights $\Phi_0$ and updated to $\Phi_0 + \Delta \Phi$ by gradient descent on the conditional language modeling objective:

$$ \max_{\Phi} \sum_{(x,y) \in \mathcal{Z}} \sum_{t=1}^{|y|} \log P_{\Phi}(y_t \mid x, y_{<t}) $$
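Concretely, this objective is the token-level log-likelihood of the target tokens conditioned on the prompt. The sketch below computes its negative (a loss to minimize) for a single $(x, y)$ pair; it assumes a Hugging Face-style causal language model whose output exposes a `.logits` tensor, which is an assumption rather than something specified above.

```python
import torch
import torch.nn.functional as F

def conditional_lm_loss(model, x_ids, y_ids):
    """Negative of sum_t log P_Phi(y_t | x, y_<t) for one (x, y) pair.

    Assumes a causal LM returning logits of shape (1, seq_len, vocab_size).
    """
    # Feed the concatenated context and target through the model.
    input_ids = torch.cat([x_ids, y_ids], dim=1)            # (1, |x| + |y|)
    logits = model(input_ids).logits                         # (1, |x| + |y|, V)

    # The logit at position i predicts token i + 1, so the logits that
    # score y_1 ... y_|y| start at index |x| - 1.
    start = x_ids.size(1) - 1
    target_logits = logits[:, start:start + y_ids.size(1), :]

    # Gather log P(y_t | x, y_<t) for each target token and sum over t.
    log_probs = F.log_softmax(target_logits, dim=-1)
    token_log_probs = log_probs.gather(-1, y_ids.unsqueeze(-1)).squeeze(-1)
    return -token_log_probs.sum()
```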

Training on each downstream task results in a different set of parameters $\Delta \Phi$ with the same dimensions as $\Phi_0$. Low-Rank Adaptation instead encodes each task-specific update with a much smaller set of parameters $\Theta$, with $|\Theta| \ll |\Phi_0|$:

$$ \max_{\Theta} \sum_{(x,y) \in \mathcal{Z}} \sum_{t=1}^{|y|} \log P_{\Phi_0 + \Delta \Phi(\Theta)}(y_t \mid x, y_{<t}) $$
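In practice this means optimizing the same loss while keeping $\Phi_0$ frozen and exposing only $\Theta$ (the injected low-rank matrices) to the optimizer. A minimal PyTorch sketch of that setup follows; the `"lora_"` prefix used to identify the injected parameters is an assumed naming convention for illustration, not something prescribed above.

```python
import torch

def mark_only_lora_trainable(model: torch.nn.Module):
    """Freeze Phi_0 and leave only the injected low-rank parameters Theta trainable."""
    lora_params = []
    for name, param in model.named_parameters():
        if "lora_" in name:              # assumed naming convention for A and B
            param.requires_grad = True
            lora_params.append(param)
        else:
            param.requires_grad = False  # pre-trained weight: frozen
    return lora_params

# Usage sketch: gradient descent now updates Theta only; Phi_0 never changes.
# optimizer = torch.optim.AdamW(mark_only_lora_trainable(model), lr=1e-4)
```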

For a pre-trained weight matrix $W_0 \in \mathbb{R}^{d \times k}$, the weight update is constrained to a low-rank decomposition $W_0 + \Delta W = W_0 + BA$, where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and the rank $r \ll \min(d, k)$. $A$ is initialized with a random Gaussian and $B$ with zeros, so that $\Delta W = BA = 0$ at the beginning of training. The modified forward pass then yields:

$$ h = W_0 x + \Delta W x = W_0 x + BAx $$
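A minimal PyTorch module illustrating this forward pass might look like the following; the class name, initialization scale, and the decision to store $W_0$ as a frozen parameter are implementation choices made here for the sketch, not requirements from the text.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen pre-trained weight W0 plus a trainable low-rank update BA."""

    def __init__(self, d: int, k: int, r: int):
        super().__init__()
        # Pre-trained weight W0 (d x k); frozen, so it receives no gradients.
        self.W0 = nn.Parameter(torch.empty(d, k), requires_grad=False)
        # A (r x k) is initialized from a Gaussian, B (d x r) with zeros,
        # so that Delta W = BA = 0 at the start of training.
        self.A = nn.Parameter(torch.randn(r, k) * 0.01)
        self.B = nn.Parameter(torch.zeros(d, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W0 x + B A x, with x of shape (..., k) and h of shape (..., d).
        return x @ self.W0.T + x @ self.A.T @ self.B.T
```

For example, `LoRALinear(d=768, k=768, r=8)` replaces a 768-by-768 dense update (about 590k parameters) with roughly 12k trainable parameters in $A$ and $B$, while the output is unchanged at initialization because $BA = 0$.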