Batch Normalization and Initialization¶
Scenario: Training a Deep Image Classifier¶
You're training a ResNet-like network for image classification, but gradients vanish in deeper layers. Use batch normalization and proper initialization to stabilize training, enabling faster convergence and deeper architectures.
What This Is¶
Batch normalization stabilizes training by normalizing layer inputs, and initialization sets the starting point for learning. Together they determine whether a network trains smoothly or stalls before it starts.
When You Use It¶
- training deep networks where gradients vanish or explode
- speeding up convergence so you can use higher learning rates
- stabilizing training when adding more layers
- debugging a network that trains on one configuration but fails on another
Tooling¶
- `nn.BatchNorm1d` and `nn.BatchNorm2d` for batch normalization
- `nn.LayerNorm` for sequence or transformer models
- `nn.init.xavier_uniform_` and `nn.init.xavier_normal_` for sigmoid/tanh activations
- `nn.init.kaiming_uniform_` and `nn.init.kaiming_normal_` for ReLU activations
- `nn.init.zeros_` and `nn.init.ones_` for bias and gain terms
How Batch Norm Works¶
Batch normalization normalizes activations across the batch dimension during training, then uses running statistics during evaluation.
```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
```
Key behaviors:
- in `model.train()` mode, it computes batch statistics and updates the running mean/variance
- in `model.eval()` mode, it uses the stored running statistics
- if you forget `model.eval()`, validation becomes noisy because it uses batch-level statistics
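These mode differences are easy to verify directly. A minimal sketch (layer size and input scale are arbitrary choices):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 5  # batch with non-zero mean and variance

bn.train()
y_train = bn(x)  # normalized with this batch's statistics
# running_mean has now moved toward the batch mean (momentum=0.1 by default)
print(bn.running_mean)

bn.eval()
y_eval = bn(x)   # normalized with the stored running statistics instead
print(y_train.mean(dim=0))  # ~0 per feature in train mode
```

In train mode the output is exactly zero-mean per feature for this batch; in eval mode it is not, because the running statistics have only partially caught up with the data.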
When To Use Which Normalization¶
| Layer | Best For | Notes |
|---|---|---|
| `BatchNorm1d` | MLPs with batch dimension | needs batch size > 1 |
| `BatchNorm2d` | CNNs | normalizes per channel |
| `LayerNorm` | transformers, sequence models | normalizes per sample, batch-size independent |
| `GroupNorm` | small batch training | compromise between batch and layer norm |
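The axis each layer normalizes over can be checked directly. A sketch with an arbitrary (batch, features) tensor:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 16)   # (batch, features)

bn = nn.BatchNorm1d(16)  # normalizes each feature across the batch
ln = nn.LayerNorm(16)    # normalizes each sample across its features

y_bn = bn(x)
y_ln = ln(x)

print(y_bn.mean(dim=0))  # ~0 for every feature (normalized over the batch axis)
print(y_ln.mean(dim=1))  # ~0 for every sample (normalized over the feature axis)
```

This is why `LayerNorm` is batch-size independent: its statistics never involve other samples in the batch.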
Initialization Matters¶
The default PyTorch initialization works for many architectures, but when training stalls or gradients vanish, explicit initialization can help.
Xavier (Glorot) — for sigmoid/tanh¶
```python
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
```
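A quick sanity check of the Xavier scale. The layer shape below is an arbitrary choice for illustration:

```python
import math
import torch.nn as nn

layer = nn.Linear(300, 100)  # fan_in=300, fan_out=100
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)

# Xavier sets Var(W) = 2 / (fan_in + fan_out), so std ≈ sqrt(2/400) ≈ 0.071
expected = math.sqrt(2 / (300 + 100))
print(layer.weight.std().item(), expected)
```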
Kaiming (He) — for ReLU¶
```python
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
nn.init.zeros_(layer.bias)
```
Custom initialization loop¶
```python
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model.apply(init_weights)
```
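`model.apply` recurses into every submodule, so one function covers the whole network. A self-contained sketch (the model here is an arbitrary example):

```python
import math
import torch.nn as nn

def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 10))
model.apply(init_weights)  # visits every submodule, including nested ones

# Kaiming-normal draws from N(0, 2/fan_in): std ≈ sqrt(2/256) ≈ 0.088 here
print(model[0].weight.std().item(), math.sqrt(2 / 256))
```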
Failure Pattern¶
Training a deep network with default initialization and no normalization, then blaming the architecture when gradients collapse after a few layers.
Another trap: using batch normalization with a batch size of 1 during training, which makes the batch statistics meaningless.
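The batch-size-1 trap is easy to reproduce: in train mode, PyTorch refuses to compute batch statistics from a single sample, while eval mode works because it uses the stored running statistics. A sketch (the feature size is arbitrary):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8)
single = torch.randn(1, 8)  # batch of one

bn.train()
try:
    bn(single)  # the variance of a single sample is undefined
except ValueError as err:
    print("train-mode failure:", err)

bn.eval()
out = bn(single)  # fine: normalizes with running statistics
print(out.shape)
```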
Common Mistakes¶
- forgetting to switch to `eval()` mode, making batch norm use per-batch statistics during validation
- using `BatchNorm` in a model that processes single samples at inference time without switching to eval mode
- initializing all weights to zero, which makes every neuron learn the same thing
- mixing Xavier init with ReLU activations, where Kaiming is more appropriate
Practice¶
- Add batch normalization to a simple MLP and compare convergence speed.
- Swap `BatchNorm1d` for `LayerNorm` and explain when each is preferable.
- Initialize a network with Kaiming init versus default and compare the first-epoch loss.
- Show what happens when you forget `model.eval()` with batch norm during validation.
- Apply a custom initialization function and inspect the weight distributions.
Case Study: Batch Norm in ResNet¶
ResNet's success in deep networks is partly due to batch normalization, which stabilized training and allowed 100+ layers. This revolutionized computer vision by enabling very deep architectures.
Expanded Quick Quiz¶
Why does batch normalization help deep networks?
Answer: It normalizes layer inputs, reducing internal covariate shift and allowing higher learning rates.
When should you use LayerNorm over BatchNorm?
Answer: In transformers or when batch sizes are small/variable, as LayerNorm is batch-size independent.
How does Kaiming initialization differ from Xavier?
Answer: Kaiming accounts for ReLU's non-linearity, using a different variance scale for better convergence.
In the image classifier scenario, why is initialization critical?
Answer: Poor initialization can cause vanishing gradients in deep networks, preventing learning.
Progress Checkpoint¶
- [ ] Added batch normalization to a network and observed convergence.
- [ ] Compared BatchNorm vs. LayerNorm in different scenarios.
- [ ] Applied proper initialization (Kaiming/Xavier) and checked weight distributions.
- [ ] Switched to eval mode for validation and verified behavior.
- [ ] Answered quiz questions without peeking.
Milestone: Complete this to unlock "Learning Rate Schedulers" in the Deep Learning track. Share your normalization comparison in the academy Discord!
Further Reading¶
- "Batch Normalization: Accelerating Deep Network Training" paper.
- PyTorch Initialization Docs.
- Normalization techniques in deep learning.
Runnable Example¶
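The section's two ideas combine into one runnable sketch: a deep ReLU MLP whose activations collapse under default initialization, but survive with either Kaiming initialization or batch normalization. Depth, width, and names here are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_mlp(depth=20, width=256, use_bn=False):
    layers = []
    for _ in range(depth):
        layers.append(nn.Linear(width, width, bias=False))
        if use_bn:
            layers.append(nn.BatchNorm1d(width))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

def kaiming(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")

x = torch.randn(64, 256)

default_net = make_mlp()                 # default PyTorch init, no norm
bn_net = make_mlp(use_bn=True)           # default init + batch norm
he_net = make_mlp()
he_net.apply(kaiming)                    # explicit Kaiming init

with torch.no_grad():
    # default init: the activation scale shrinks layer by layer and collapses
    print("default  :", default_net(x).std().item())
    # batch norm re-standardizes every layer, rescuing the default init
    print("batchnorm:", bn_net(x).std().item())
    # Kaiming init preserves the activation scale through ReLU layers
    print("kaiming  :", he_net(x).std().item())
```

Running this shows the default-init network's output standard deviation near zero after 20 layers, while the batch-norm and Kaiming variants keep their activations at a healthy scale.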
Connections¶
Continue with PyTorch Training Loops for the full loop structure, and Optimizers and Regularization for regularization techniques that complement batch normalization.