Stochastic Weight Averaging

Stochastic Weight Averaging (SWA) is a training procedure that averages the weights of several models proposed by SGD in the final phase of training, using a learning rate schedule that encourages SGD to keep exploring the same basin of attraction rather than settling at a single point.

SGD tends to converge near the boundary of a wide, flat region of the training loss surface. Because the test loss surface is shifted relative to the training loss, these boundary points often generalize worse than the center of the region; averaging several SGD iterates moves the solution toward that center.

The learning rate schedule is a cyclical linear decay from $\alpha_1$ to $\alpha_2$ with cycle length $c$, where the value at iteration $i$ is:

\begin{align} \alpha(i) &= (1 - t(i))\,\alpha_1 + t(i)\,\alpha_2 \\ t(i) &= \frac{1}{c}\big(\operatorname{mod}(i - 1, c) + 1\big) \end{align}
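The schedule above can be sketched as a small helper; `swa_lr` is a hypothetical name, and the concrete values of $\alpha_1$, $\alpha_2$, and $c$ below are illustrative, not from the source.

```python
def swa_lr(i, alpha1, alpha2, c):
    """Cyclical linear decay from alpha1 to alpha2 with cycle length c.

    t steps through 1/c, 2/c, ..., 1 within each cycle, so the learning
    rate decays linearly and resets to near alpha1 at the start of the
    next cycle (iterations are 1-indexed as in the formula).
    """
    t = ((i - 1) % c + 1) / c
    return (1 - t) * alpha1 + t * alpha2


# Example: alpha1=0.05, alpha2=0.01, c=5
# Within a cycle the rate decays toward alpha2, then jumps back up.
rates = [swa_lr(i, 0.05, 0.01, 5) for i in range(1, 11)]
```

Note that the rate reaches exactly $\alpha_2$ at the end of each cycle (when $\operatorname{mod}(i-1, c) + 1 = c$), which is precisely when the weight average is updated.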

The aggregated weight average is updated with the current weights $w$ at the end of each cycle, where $n_{\text{models}}$ is the number of models averaged so far:

$$ w_{\text{SWA}} \leftarrow \frac{w_{\text{SWA}} \cdot n_{\text{models}} + w}{n_{\text{models}} + 1} $$
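This update is an incremental (running) mean, so $w_{\text{SWA}}$ always equals the plain average of all models collected so far without storing them. A minimal sketch over flat weight vectors, with the hypothetical name `swa_update`:

```python
def swa_update(w_swa, w, n_models):
    """Incorporate weights w into the running average w_swa.

    w_swa currently averages n_models models; the result averages
    n_models + 1. Weights are represented as flat lists of floats
    for illustration.
    """
    return [(ws * n_models + wi) / (n_models + 1)
            for ws, wi in zip(w_swa, w)]


# Averaging the vectors [2.0], [4.0], [6.0] one at a time:
w_swa = [0.0]
for n, w in enumerate([[2.0], [4.0], [6.0]]):
    w_swa = swa_update(w_swa, w, n)
# w_swa is now [4.0], the mean of the three vectors
```

In practice the same update is applied elementwise to every parameter tensor of the network.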