Knowledge Distillation

Knowledge distillation is a classic transfer learning technique in which a smaller model is trained to mimic the outputs of a larger model or ensemble. This can allow a much smaller model to learn complex representations that would have been impossible for it to learn from the raw data alone.

The simplest form of distillation uses the class probabilities produced by the large model as soft targets for the small model, instead of the hard class labels. When the soft targets have high entropy, they provide much more information per training sample, and much less variance in the gradient between samples, than hard targets. When the larger model is very confident, much of the information about the learned function lies in the ratios of very small probabilities in the soft targets. This information defines a similarity structure over the data, even though it has very little influence on the cross-entropy loss. Distillation exposes this information by raising the temperature $T$ of the softmax that converts each logit $z_i$ into a probability $q_i$. Setting $T = 1$ recovers the standard softmax, and higher values produce a softer distribution over classes.

$$ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} $$
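
To make the effect of the temperature concrete, here is a minimal NumPy sketch of the softmax above. The function name `softmax_with_temperature` and the example logits are illustrative assumptions, not part of any particular library.

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Return q_i = exp(z_i / T) / sum_j exp(z_j / T) along the last axis."""
    z = np.asarray(logits, dtype=float) / T
    # Subtracting the max does not change the result but avoids overflow in exp.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical logits from a confident large model.
logits = np.array([8.0, 2.0, -1.0])

q_sharp = softmax_with_temperature(logits, T=1.0)   # close to a one-hot vector
q_soft = softmax_with_temperature(logits, T=10.0)   # softer; small classes become visible

print("T=1 :", q_sharp)
print("T=10:", q_soft)
```

At the higher temperature the non-maximal classes receive noticeably more probability mass, which is what lets the small model see the relative ordering of the unlikely classes.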

The small model can be trained on the original training set or on a separate, possibly unlabeled, transfer set. When true labels are available, the objective function can combine two terms: one that matches the soft targets and one that matches the true labels. Typically the small model cannot match the soft targets exactly, and erring in the direction of the correct answer is helpful.
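
One way such a combined objective might look is sketched below in NumPy, as a weighted sum of a soft-target cross-entropy at temperature $T$ and an ordinary cross-entropy on the true labels. The function names, the mixing weight `alpha`, and the `T**2` rescaling of the soft term (a common choice to keep the two terms on comparable scales as $T$ changes) are assumptions made for illustration and are not prescribed by the text above.

```python
import numpy as np

def log_softmax(logits, T=1.0):
    """Numerically stable log of the temperature softmax along the last axis."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Weighted sum of a soft-target term and a hard-label term (illustrative).

    student_logits, teacher_logits: arrays of shape (batch, classes)
    labels: integer class indices of shape (batch,)
    alpha: assumed weight on the soft-target term, between 0 and 1
    """
    # Soft-target term: cross-entropy between teacher and student, both at temperature T.
    p_teacher = np.exp(log_softmax(teacher_logits, T))
    soft = -(p_teacher * log_softmax(student_logits, T)).sum(axis=-1).mean()

    # Hard-label term: ordinary cross-entropy of the student at T = 1.
    log_q = log_softmax(student_logits, 1.0)
    hard = -log_q[np.arange(len(labels)), labels].mean()

    # Assumed rescaling of the soft term by T**2.
    return alpha * T**2 * soft + (1.0 - alpha) * hard

# Hypothetical batch of two samples and three classes.
student = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
teacher = np.array([[3.0, 1.0, -2.0], [-1.0, 0.0, 4.0]])
labels = np.array([0, 2])

print(distillation_loss(student, teacher, labels))
```

Setting `alpha=0.0` reduces this to plain supervised training, and `alpha=1.0` trains on the soft targets alone, which is the only option on an unlabeled transfer set.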

Each sample in the transfer set contributes a cross-entropy gradient $\partial C / \partial z_i$ with respect to each logit $z_i$ of the distilled model. If the larger model has logits $v_i$, which produce the soft target probabilities, the gradient at temperature $T$ is given by:

$$ \frac{\partial C}{\partial z_i} = \frac{1}{T} \left( \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} - \frac{\exp(v_i / T)}{\sum_j \exp(v_j / T)} \right) $$

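If the temperature is high compared to the magnitude of the logits, each exponential can be replaced by its first-order expansion $\exp(x) \approx 1 + x$, which gives the first line below, where $N$ is the number of classes. If the logits are also zero-meaned for each transfer sample, so that $\sum_j z_j = \sum_j v_j = 0$, the sums in the denominators vanish and the gradient simplifies to the second line: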

\begin{gather*} \frac{\partial C}{\partial z_i} \approx \frac{1}{T} \left( \frac{1 + z_i / T}{N + \sum_j z_j / T} - \frac{1 + v_i / T}{N + \sum_j v_j / T} \right) \\ \frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2} (z_i - v_i) \end{gather*}
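
The approximation can be checked numerically. The following sketch compares the exact gradient with the high-temperature form $(z_i - v_i) / (N T^2)$ for zero-mean logits. The random logits, the helper names, and the relative-error measure are illustrative assumptions.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def exact_grad(z, v, T):
    """Exact gradient: (softmax(z / T) - softmax(v / T)) / T."""
    return (softmax(z / T) - softmax(v / T)) / T

def approx_grad(z, v, T):
    """High-temperature approximation: (z - v) / (N * T^2)."""
    return (z - v) / (len(z) * T**2)

rng = np.random.default_rng(0)
# Hypothetical student (z) and teacher (v) logits, zero-meaned per sample.
z = rng.normal(size=5)
z -= z.mean()
v = rng.normal(size=5)
v -= v.mean()

rel_errors = {}
for T in (1.0, 10.0, 100.0):
    g_exact = exact_grad(z, v, T)
    g_approx = approx_grad(z, v, T)
    # Largest absolute deviation, relative to the largest approximate component.
    rel_errors[T] = np.abs(g_exact - g_approx).max() / np.abs(g_approx).max()
    print(f"T={T:>5}: relative error = {rel_errors[T]:.4f}")
```

The relative error should shrink as $T$ grows, which reflects that in the high-temperature limit distillation amounts to pulling the student logits $z_i$ toward the teacher logits $v_i$, as minimizing $\frac{1}{2}(z_i - v_i)^2$ would.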