Batch Normalisation
Last updated
Copyright Continuum Labs - 2023
This November 2018 paper, "Understanding Batch Normalization" (Bjorck et al., NeurIPS 2018), provides an extensive empirical analysis of batch normalization (BN) to explain why it is effective in deep learning.
They show that the primary benefit of BN is that it enables training with larger learning rates, which leads to faster convergence and better generalization.
They demonstrate that without BN, large learning rates can cause the loss to diverge and the activations to grow uncontrollably with network depth.
They explain that BN avoids this problem by continually normalizing activations to zero mean and unit standard deviation, which permits larger gradient steps, yields faster convergence and may help the optimizer bypass sharp local minima.
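For reference, the standard batch-normalizing transform (as introduced by Ioffe and Szegedy, 2015) for a single feature over a mini-batch of $m$ values $x_1, \dots, x_m$ is shown below. Here $\gamma$ and $\beta$ are learnable per-feature scale and shift parameters, and $\epsilon$ is a small constant added for numerical stability.

```latex
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i,
\qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2,

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}},
\qquad
y_i = \gamma\,\hat{x}_i + \beta .
```

The learnable $\gamma$ and $\beta$ let the network recover a different scale and offset when that helps, so the normalization step does not restrict what a layer can represent.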
They illustrate various ways in which gradients and activations of deep unnormalized networks can be ill-behaved, and they connect these findings to recent results in random matrix theory.
The paper begins by disentangling the benefits of BN.
The authors train a 110-layer ResNet on CIFAR-10 with and without BN, using various learning rates. They find that BN networks with low learning rates perform similarly to unnormalized networks, suggesting that the higher learning rates enabled by BN are the primary driver of its benefits.
The authors then analyze why BN allows for higher learning rates.
They observe that gradients in unnormalized networks have heavy tails and are orders of magnitude larger than in BN networks. They show that for unnormalized networks, only small learning rates reduce the loss, while BN networks can use a much wider range. They define network divergence as the loss increasing beyond a certain threshold and show that activations and their variance grow rapidly with depth during divergence in unnormalized networks. BN prevents this by normalizing activations at each layer.
Next, the authors examine the behavior of activations and gradients at initialization. They find that even without divergence, the means and variances of activations grow exponentially with depth in unnormalized networks. They also show that in the final layer, unnormalized networks tend to be biased towards certain classes regardless of the input, leading to highly correlated gradients across examples. In contrast, BN networks have more input-dependent gradients.
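A toy simulation can make the depth effect concrete. This sketch is not the paper's experiment (which used deep ResNets): it assumes a plain fully connected ReLU network without residual connections, whose Gaussian weights are deliberately scaled by a hypothetical `gain` of 2.0, slightly above the He-initialization value of sqrt(2), so that the second moment of the activations roughly doubles at every layer. It records the root-mean-square (RMS) of each layer's pre-activations, with and without per-feature batch normalization, using the same random weights in both runs.

```python
import math
import random


def rms(rows):
    """Root-mean-square over every entry of a batch (a list of equal-length rows)."""
    count = sum(len(r) for r in rows)
    return math.sqrt(sum(v * v for r in rows for v in r) / count)


def normalize_per_feature(rows, eps=1e-5):
    """Normalize each column (feature) to zero mean and unit variance over the batch."""
    m, width = len(rows), len(rows[0])
    out = [list(r) for r in rows]
    for j in range(width):
        mean = sum(r[j] for r in rows) / m
        var = sum((r[j] - mean) ** 2 for r in rows) / m
        for i in range(m):
            out[i][j] = (rows[i][j] - mean) / math.sqrt(var + eps)
    return out


def rms_by_depth(depth=30, width=32, batch=16, gain=2.0, use_bn=False, seed=0):
    """Return the RMS of the pre-activations at each layer of a random ReLU MLP."""
    rng = random.Random(seed)
    x = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]
    std = gain / math.sqrt(width)  # assumed weight scale; He init would use gain = sqrt(2)
    history = []
    for _ in range(depth):
        w = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        # Linear layer without bias: z[b][j] = sum_k x[b][k] * w[k][j]
        z = [[sum(row[k] * w[k][j] for k in range(width)) for j in range(width)]
             for row in x]
        if use_bn:
            z = normalize_per_feature(z)
        history.append(rms(z))
        x = [[max(0.0, v) for v in row] for row in z]  # ReLU
    return history


if __name__ == "__main__":
    plain = rms_by_depth(use_bn=False)
    normed = rms_by_depth(use_bn=True)
    for layer in (1, 10, 20, 30):
        print(f"layer {layer:2d}: no BN {plain[layer - 1]:12.2f}   BN {normed[layer - 1]:.4f}")
```

Under these assumptions the unnormalized RMS grows geometrically with depth (roughly by a factor of sqrt(2) per layer), while the normalized version stays pinned near 1 at every layer by construction.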
For convolutional layers, the authors decompose the gradient of each kernel and find that for unnormalized networks, the gradients are similar across spatial dimensions and examples, encoding little input-dependent information. They also show that certain input and output channels consistently have much larger gradients than others.
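The decomposition can be written out for a single convolutional layer. Assuming a stride-1 convolution (implemented as cross-correlation, as deep-learning libraries do) with kernel $K$, input activations $h^{b}$ and outputs $z^{b}$ for example $b$, and writing $\delta^{b} = \partial L / \partial z^{b}$ (the notation here is assumed, not taken from the paper), the kernel gradient is a sum of one term per example and per output position:

```latex
z^{b}_{o,x,y} = \sum_{i}\sum_{\Delta x,\Delta y} K_{o,i,\Delta x,\Delta y}\; h^{b}_{i,\,x+\Delta x,\,y+\Delta y}

\frac{\partial L}{\partial K_{o,i,\Delta x,\Delta y}}
  = \sum_{b}\sum_{x,y} \delta^{b}_{o,x,y}\; h^{b}_{i,\,x+\Delta x,\,y+\Delta y}
```

The observation summarised above is that, in unnormalized networks, the individual terms inside this double sum are very similar to one another, so the summed gradient carries little input-specific information.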
Finally, the authors connect their findings to recent results in random matrix theory, suggesting that traditional initialization schemes may be unsuitable for deep networks without BN.
In summary, this paper provides a thorough empirical investigation into how batch normalization improves deep network training. The key insights are that BN enables larger learning rates by preventing activation explosion and that it leads to more input-dependent gradients. These findings deepen our understanding of this widely used technique.
Batch normalisation is primarily a technique for stabilising and accelerating training; because the batch statistics it uses vary from batch to batch, it also has a mild regularising effect.
During training it normalises each feature using the mean and variance computed over the current mini-batch, and it also maintains running averages (typically exponential moving averages) of those per-feature statistics. These running averages are extra state in the model: they are not trainable parameters and are not updated by gradient descent, so they have to be managed explicitly.
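A minimal sketch of how such running averages are typically maintained: an exponential moving average updated once per training step. The function name and the momentum value of 0.9 are assumptions for illustration, and frameworks disagree about what "momentum" means, as the comment notes.

```python
def update_running_stats(running_mean, running_var, batch_mean, batch_var, momentum=0.9):
    """Exponential moving average of per-feature batch statistics.

    Assumed convention: `momentum` weights the OLD running value, as in Keras
    and Flax (both default to 0.99). PyTorch's BatchNorm uses the opposite
    convention: its `momentum` (default 0.1) weights the NEW batch value.
    """
    new_mean = [momentum * r + (1.0 - momentum) * b
                for r, b in zip(running_mean, batch_mean)]
    new_var = [momentum * r + (1.0 - momentum) * b
               for r, b in zip(running_var, batch_var)]
    return new_mean, new_var


if __name__ == "__main__":
    mean, var = [0.0], [1.0]  # common initial values for one feature
    for step in range(1, 51):
        # Pretend every batch has mean 5 and variance 4 for this feature.
        mean, var = update_running_stats(mean, var, [5.0], [4.0])
        if step in (1, 10, 50):
            print(step, round(mean[0], 3), round(var[0], 3))
```

Because the average lags behind the batch statistics, the running values only become reliable after enough training steps: with momentum 0.9, the influence of the initial values shrinks by a factor of 0.9 per step.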
Adding batch normalisation layers changes a model's architecture.
A model without them is typically a plain stack of layers, such as dense layers followed by activation functions, whose only state is its trainable weights.
A batch normalisation layer, commonly placed between a linear layer and its activation, normalises each feature of its input to zero mean and unit variance over the batch and then applies a learnable per-feature scale and shift. This keeps the distribution of each layer's inputs in a stable range, which in practice makes training more stable and tolerant of higher learning rates.
A critical aspect of batch normalisation is that it behaves differently during training and inference.
During training, each batch normalisation layer normalises using the mean and variance of the current mini-batch and updates its running averages of those statistics.
During inference, it uses the running averages accumulated during training instead, so that the prediction for an example is deterministic and does not depend on which other examples happen to share its batch. Switching correctly between these two modes is essential for the model to behave as intended.
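The two modes can be illustrated with a small stateful layer. This is a minimal sketch, not any library's implementation: the class name `SimpleBatchNorm`, the `train` flag, the momentum of 0.9 and the pure-Python list representation are all assumptions. It takes a batch of feature vectors and normalises each feature independently.

```python
import math


class SimpleBatchNorm:
    """Hypothetical minimal batch normalisation layer over lists of feature vectors."""

    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.gamma = [1.0] * num_features         # learnable scale
        self.beta = [0.0] * num_features          # learnable shift
        self.running_mean = [0.0] * num_features  # non-trainable state
        self.running_var = [1.0] * num_features   # non-trainable state
        self.momentum = momentum
        self.eps = eps

    def __call__(self, batch, train):
        n = len(self.gamma)
        if train:
            # Training: statistics come from the current mini-batch...
            m = len(batch)
            mean = [sum(row[j] for row in batch) / m for j in range(n)]
            var = [sum((row[j] - mean[j]) ** 2 for row in batch) / m for j in range(n)]
            # ...and the running averages are updated as a side effect.
            k = self.momentum
            self.running_mean = [k * r + (1 - k) * b for r, b in zip(self.running_mean, mean)]
            self.running_var = [k * r + (1 - k) * b for r, b in zip(self.running_var, var)]
        else:
            # Inference: use the accumulated statistics and ignore the batch.
            mean, var = self.running_mean, self.running_var
        return [
            [self.gamma[j] * (row[j] - mean[j]) / math.sqrt(var[j] + self.eps) + self.beta[j]
             for j in range(n)]
            for row in batch
        ]


if __name__ == "__main__":
    bn = SimpleBatchNorm(2)
    data = [[1.0, 10.0], [3.0, 14.0]]
    for _ in range(200):
        bn(data, train=True)
    print(bn([[1.0, 10.0]], train=False))
    print(bn([[1.0, 10.0], [100.0, -50.0]], train=False))
```

In evaluation mode the output for a given example is the same whether it is processed alone or next to very different examples; in training mode it would change with the batch, and each call would also move the running averages.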
Batch normalisation layers therefore carry extra state: the running averages of the batch means and variances. This state sits alongside the model's trainable parameters but is not updated by gradient descent, so it must be handled explicitly throughout the model's lifecycle, including saving it in checkpoints and switching correctly between training and evaluation modes.
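As a minimal illustration of checkpointing, the hypothetical helpers below write the trainable parameters and the running statistics to a single JSON file and restore them together. JSON and the nested-dictionary layout (`params` and `batch_stats`, loosely mirroring the collection names Flax uses) are assumptions chosen for readability; real projects would normally rely on their framework's own checkpointing utilities.

```python
import json


def save_checkpoint(path, params, batch_stats, step):
    """Write trainable parameters and BN running statistics to one JSON file.

    Storing both together means a restored model evaluates with the same
    statistics it was trained with.
    """
    with open(path, "w") as f:
        json.dump({"step": step, "params": params, "batch_stats": batch_stats}, f)


def load_checkpoint(path):
    """Restore (params, batch_stats, step) written by save_checkpoint."""
    with open(path) as f:
        ckpt = json.load(f)
    return ckpt["params"], ckpt["batch_stats"], ckpt["step"]
```

Saving only `params` and later restoring into a freshly initialised model would silently reset the running means to 0 and variances to 1, and evaluation would then normalise with the wrong statistics.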
When batch normalisation is used in a training loop, the running statistics must be updated correctly at every iteration and carried forward to the next one.
In practice, the training step has to run the model in training mode, collect the updated statistics, and pass them on to the next step along with the updated parameters.
During evaluation, the model must use the accumulated running statistics rather than statistics computed from the current batch. This keeps evaluation results consistent and reproducible, so the evaluation function needs to run the model in inference mode and must not modify the stored statistics.
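Functional frameworks make this threading of state explicit; in Flax, for example, calling `apply` with `mutable=['batch_stats']` returns the updated statistics alongside the model output. The sketch below mimics that pattern in plain Python, assuming a hypothetical pure function `batch_norm_apply` that never mutates its inputs, plus hypothetical `train_epoch` and `eval_step` helpers. Gradient updates to the parameters are deliberately omitted so the focus stays on the statistics.

```python
import math


def batch_norm_apply(params, stats, batch, train, momentum=0.9, eps=1e-5):
    """Pure function returning (outputs, new_stats); inputs are never mutated.

    `params` holds trainable gamma/beta; `stats` holds running mean/var.
    In eval mode the returned stats are the input stats, unchanged.
    """
    n = len(params["gamma"])
    if train:
        m = len(batch)
        mean = [sum(r[j] for r in batch) / m for j in range(n)]
        var = [sum((r[j] - mean[j]) ** 2 for r in batch) / m for j in range(n)]
        new_stats = {
            "mean": [momentum * s + (1 - momentum) * b for s, b in zip(stats["mean"], mean)],
            "var": [momentum * s + (1 - momentum) * b for s, b in zip(stats["var"], var)],
        }
    else:
        mean, var, new_stats = stats["mean"], stats["var"], stats
    out = [[params["gamma"][j] * (r[j] - mean[j]) / math.sqrt(var[j] + eps) + params["beta"][j]
            for j in range(n)] for r in batch]
    return out, new_stats


def train_epoch(params, stats, batches):
    """Training loop: feed each step's updated statistics into the next step."""
    for batch in batches:
        _, stats = batch_norm_apply(params, stats, batch, train=True)
    return stats


def eval_step(params, stats, batch):
    """Evaluation: inference mode with the accumulated statistics."""
    out, _ = batch_norm_apply(params, stats, batch, train=False)
    return out
```

A common bug this structure guards against is discarding the returned statistics in the training loop, which leaves the running averages at their initial values and makes evaluation use the wrong normalisation.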
Batch normalisation remains a widely used technique, especially in convolutional networks; transformer-based large language models generally use layer normalisation (or variants such as RMSNorm) instead, because those normalise each example independently of the rest of the batch.
Using batch normalisation successfully depends on understanding how it works, how it changes the model's architecture and state, and how its behaviour differs between training and inference.
With these points handled, practitioners can train faster and more stably and obtain reliable behaviour at inference time.