Batch Normalization and Initialization

Scenario: Training a Deep Image Classifier

You're training a ResNet-like network for image classification, but gradients vanish in deeper layers. Use batch normalization and proper initialization to stabilize training, enabling faster convergence and deeper architectures.

What This Is

Batch normalization stabilizes training by normalizing layer inputs, and initialization sets the starting point for learning. Together they determine whether a network trains smoothly or stalls before it starts.

When You Use It

  • training deep networks where gradients vanish or explode
  • speeding up convergence so you can use higher learning rates
  • stabilizing training when adding more layers
  • debugging a network that trains on one configuration but fails on another

Tooling

  • nn.BatchNorm1d and nn.BatchNorm2d for batch normalization
  • nn.LayerNorm for sequence or transformer models
  • nn.init.xavier_uniform_ and nn.init.xavier_normal_ for sigmoid/tanh activations
  • nn.init.kaiming_uniform_ and nn.init.kaiming_normal_ for ReLU activations
  • nn.init.zeros_ and nn.init.ones_ for bias and gain terms

How Batch Norm Works

Batch normalization normalizes activations across the batch dimension during training, then uses running statistics during evaluation.

import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

Key behaviors:

  • in model.train() mode, it computes batch statistics and updates running mean/variance
  • in model.eval() mode, it uses the stored running statistics
  • if you forget model.eval(), validation becomes noisy because it uses batch-level statistics
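To make the train/eval difference concrete, here is a minimal sketch (the seed, layer width, and batch shape are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# A batch whose features are deliberately not zero-mean / unit-variance
x = torch.randn(8, 4) * 3 + 5

bn.train()
y_train = bn(x)  # normalized with this batch's own statistics;
                 # running_mean/running_var are nudged toward them (momentum=0.1)

bn.eval()
y_eval = bn(x)   # same input, but normalized with the stored running statistics

print(y_train.mean().item())            # ~0: batch stats were used
print(torch.allclose(y_train, y_eval))  # False: eval uses running stats
```

Because the running statistics are still warming up after a single batch, the eval-mode output differs noticeably from the train-mode output here.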

When To Use Which Normalization

| Layer       | Best For                      | Notes                                         |
|-------------|-------------------------------|-----------------------------------------------|
| BatchNorm1d | MLPs with a batch dimension   | needs batch size > 1 during training          |
| BatchNorm2d | CNNs                          | normalizes per channel                        |
| LayerNorm   | transformers, sequence models | normalizes per sample; batch-size independent |
| GroupNorm   | small-batch training          | compromise between batch and layer norm       |
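The batch-versus-sample distinction above can be checked directly. This sketch (shapes are illustrative) shows which axis each layer normalizes over:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 6)  # (batch, features)

# BatchNorm1d: each feature is normalized across the batch dimension
bn = nn.BatchNorm1d(6).train()
y_bn = bn(x)
print(y_bn.mean(dim=0))  # per-feature means across the batch: ~0

# LayerNorm: each sample is normalized across the feature dimension,
# so the result does not depend on what else is in the batch
ln = nn.LayerNorm(6)
y_ln = ln(x)
print(y_ln.mean(dim=1))  # per-sample means across features: ~0
```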

Initialization Matters

The default PyTorch initialization works for many architectures, but when training stalls or gradients vanish, explicit initialization can help.

Xavier (Glorot) — for sigmoid/tanh

nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)

Kaiming (He) — for ReLU

nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
nn.init.zeros_(layer.bias)
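To see the difference in scale between the two schemes, compare the resulting weight standard deviations; the square 512x512 matrix here is purely for illustration:

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

w_xavier = torch.empty(512, 512)
nn.init.xavier_normal_(w_xavier)   # std = sqrt(2 / (fan_in + fan_out))

w_kaiming = torch.empty(512, 512)
nn.init.kaiming_normal_(w_kaiming, nonlinearity="relu")  # std = sqrt(2 / fan_in)

print(w_xavier.std().item())   # ≈ 0.0442 = sqrt(1/512)
print(w_kaiming.std().item())  # ≈ 0.0625, a factor of sqrt(2) larger
```

The extra sqrt(2) gain compensates for ReLU zeroing out roughly half of each layer's activations.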

Custom initialization loop

def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model.apply(init_weights)

Failure Pattern

Training a deep network with default initialization and no normalization, then blaming the architecture when gradients collapse after a few layers.

Another trap: using batch normalization with a batch size of 1 during training, which makes the batch statistics meaningless.
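A minimal sketch of the batch-size-1 trap (the layer width is illustrative): in train mode, BatchNorm1d rejects a single sample outright, while eval mode falls back to the running statistics and works fine.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8).train()

try:
    bn(torch.randn(1, 8))  # one sample: no batch variance to compute
    failed = False
except ValueError as err:
    failed = True  # raises "Expected more than 1 value per channel when training"
    print(err)

# In eval mode the stored running statistics are used instead,
# so single-sample inference is safe
bn.eval()
out = bn(torch.randn(1, 8))
print(out.shape)  # torch.Size([1, 8])
```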

Common Mistakes

  • forgetting to switch to eval() mode, making batch norm use per-batch statistics during validation
  • running single-sample inference through BatchNorm while still in train mode, which fails outright because one sample has no batch statistics
  • initializing all weights to zero, which makes every neuron learn the same thing
  • mixing Xavier init with ReLU activations, where Kaiming is more appropriate

Practice

  1. Add batch normalization to a simple MLP and compare convergence speed.
  2. Swap BatchNorm1d for LayerNorm and explain when each is preferable.
  3. Initialize a network with Kaiming init versus default and compare the first-epoch loss.
  4. Show what happens when you forget model.eval() with batch norm during validation.
  5. Apply a custom initialization function and inspect the weight distributions.

Case Study: Batch Norm in ResNet

ResNet's success in deep networks is partly due to batch normalization, which stabilized training and allowed 100+ layers. This revolutionized computer vision by enabling very deep architectures.

Expanded Quick Quiz

Why does batch normalization help deep networks?

Answer: It normalizes layer inputs, reducing internal covariate shift and allowing higher learning rates.

When should you use LayerNorm over BatchNorm?

Answer: In transformers or when batch sizes are small/variable, as LayerNorm is batch-size independent.

How does Kaiming initialization differ from Xavier?

Answer: Kaiming accounts for ReLU's non-linearity, using a different variance scale for better convergence.

In the image classifier scenario, why is initialization critical?

Answer: Poor initialization can cause vanishing gradients in deep networks, preventing learning.

Progress Checkpoint

  • [ ] Added batch normalization to a network and observed convergence.
  • [ ] Compared BatchNorm vs. LayerNorm in different scenarios.
  • [ ] Applied proper initialization (Kaiming/Xavier) and checked weight distributions.
  • [ ] Switched to eval mode for validation and verified behavior.
  • [ ] Answered quiz questions without peeking.

Milestone: Complete this to unlock "Learning Rate Schedulers" in the Deep Learning track. Share your normalization comparison in the academy Discord!

Further Reading

  • "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Ioffe & Szegedy, 2015).
  • PyTorch Initialization Docs.
  • Normalization techniques in deep learning.

Runnable Example

This example stays local-only for now because the browser runner does not yet include PyTorch.

Longer Connection

Continue with PyTorch Training Loops for the full loop structure, and Optimizers and Regularization for regularization techniques that complement batch normalization.