Debugging Deep Learning Models¶
Scenario: Training Loss Won't Decrease¶
You've set up a ResNet for pneumonia detection, but after 20 epochs the training loss is stuck at 0.693 (ln 2, i.e., random chance for binary cross-entropy). Something is broken. Use systematic debugging to find out whether the culprit is the data pipeline, model architecture, optimizer settings, or loss function.
Learning Objectives¶
By the end of this module (30-40 minutes), you should be able to:

- Diagnose why training isn't working using systematic checks.
- Test whether a model can overfit a single batch (sanity check).
- Inspect gradient flow to catch vanishing/exploding gradients.
- Verify data pipeline correctness (shapes, labels, normalization).
- Use visualization tools to debug activations and weights.
Prerequisites: Basic PyTorch training loops, understanding of backpropagation.
Difficulty: Intermediate.
What This Is¶
Debugging deep learning is different from debugging traditional code. The model might "run" without errors but fail to learn. This topic teaches a systematic workflow for diagnosing training failures before you waste time on architecture changes or hyperparameter sweeps.
The key insight: most training failures come from a few common mistakes, and checking them in order saves hours.
When You Use It¶
- training loss doesn't decrease
- validation accuracy stays near random chance
- loss explodes or becomes NaN
- model trains on simple tasks but fails on real data
- gradient norms are suspiciously small or large
The Debugging Ladder¶
Follow this sequence when training fails:
1. Overfit a Single Batch (The Sanity Check)¶
If your model can't memorize one batch, the architecture or loss function is broken.
```python
# Take one batch and try to get perfect loss
model.train()
inputs, labels = next(iter(train_loader))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for i in range(200):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    if i % 20 == 0:
        print(f"Step {i}: Loss = {loss.item():.4f}")

# Loss should drop to near zero
# If not: check model architecture, loss function, or data shapes
```
What to expect: Loss should drop close to zero within 100-200 steps.
If it doesn't:
- Check that outputs and labels have compatible shapes
- Verify the loss function matches the task (e.g., CrossEntropyLoss for classification)
- Check whether the model has enough capacity (at least a few layers)
2. Check Gradient Flow¶
Gradients should flow backward through all layers. If they vanish or explode, learning fails.
```python
# After loss.backward(), inspect gradients
def check_gradients(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            print(f"{name}: grad_norm = {grad_norm:.6f}")
        else:
            print(f"{name}: NO GRADIENT")

check_gradients(model)
```
What to look for:

- Vanishing gradients: norms < 1e-6 → add batch norm, check activation functions, reduce depth
- Exploding gradients: norms > 100 → use gradient clipping, lower the learning rate
- No gradient: layer is disconnected or frozen
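To see what the "NO GRADIENT" case looks like in practice, here is a minimal, self-contained demo (the two-layer model is hypothetical): freezing the first layer's parameters means they never receive a gradient, exactly as an accidentally frozen or disconnected layer would.

```python
import torch
import torch.nn as nn

# Tiny hypothetical model; freeze the first layer to simulate the bug
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for param in model[0].parameters():
    param.requires_grad = False  # accidentally frozen layer

inputs = torch.randn(3, 4)
labels = torch.tensor([0, 1, 0])
loss = nn.CrossEntropyLoss()(model(inputs), labels)
loss.backward()

for name, param in model.named_parameters():
    status = f"grad_norm = {param.grad.norm():.6f}" if param.grad is not None else "NO GRADIENT"
    print(f"{name}: {status}")
# The frozen layer (0.weight, 0.bias) prints NO GRADIENT;
# the trainable layer (2.weight, 2.bias) has finite gradient norms.
```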
3. Verify Data Pipeline¶
Bad data shapes or incorrect labels silently break training.
```python
import matplotlib.pyplot as plt

# Check one batch
inputs, labels = next(iter(train_loader))
print(f"Input shape: {inputs.shape}")
print(f"Label shape: {labels.shape}")
print(f"Input range: [{inputs.min():.3f}, {inputs.max():.3f}]")
print(f"Unique labels: {labels.unique()}")

# Visualize a sample
plt.imshow(inputs[0].permute(1, 2, 0).cpu())  # For images (CHW -> HWC)
plt.title(f"Label: {labels[0].item()}")
plt.show()
```
Common issues:

- Labels are one-hot but the loss expects class indices (or vice versa)
- Images not normalized to [0, 1] or [-1, 1]
- Channels in the wrong order (CHW vs HWC)
- Wrong number of classes in the final layer
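As an example of the first issue, `nn.CrossEntropyLoss` is conventionally fed integer class indices; if your labels arrive one-hot encoded (the tensors below are illustrative), `argmax` recovers the indices:

```python
import torch
import torch.nn as nn

# One-hot labels (shape: batch x num_classes) vs. class indices (shape: batch)
one_hot = torch.tensor([[0., 1., 0.],
                        [1., 0., 0.]])
class_indices = one_hot.argmax(dim=1)  # tensor([1, 0])

logits = torch.randn(2, 3)
loss = nn.CrossEntropyLoss()(logits, class_indices)  # integer indices as targets
print(class_indices, loss.item())
```

Note that recent PyTorch versions also accept class-probability targets, so passing one-hot labels may not raise an error, but it silently changes the target semantics; print the labels and check.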
4. Check Model Output Shape¶
The model's final layer must match the loss function's expectation.
```python
model.eval()
with torch.no_grad():
    outputs = model(inputs)

print(f"Model output shape: {outputs.shape}")
print(f"Expected shape: (batch_size, num_classes)")

# For classification, check if outputs are logits or probabilities
print(f"Output range: [{outputs.min():.3f}, {outputs.max():.3f}]")
```
For CrossEntropyLoss: Output should be raw logits (no softmax), shape (batch, num_classes).
For BCEWithLogitsLoss: Output should be logits, shape (batch, 1) or (batch,).
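A quick sketch of both pairings (the tensor sizes are arbitrary); in both cases the model output is raw logits, with no manual softmax or sigmoid:

```python
import torch
import torch.nn as nn

batch, num_classes = 4, 3

# CrossEntropyLoss: logits of shape (batch, num_classes), integer class targets
logits = torch.randn(batch, num_classes)
targets = torch.randint(0, num_classes, (batch,))
ce = nn.CrossEntropyLoss()(logits, targets)

# BCEWithLogitsLoss: logits of shape (batch,), float targets in {0, 1}
binary_logits = torch.randn(batch)
binary_targets = torch.randint(0, 2, (batch,)).float()
bce = nn.BCEWithLogitsLoss()(binary_logits, binary_targets)

print(ce.item(), bce.item())  # both are finite scalars
```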
5. Monitor Learning Rate¶
If the LR is too high, loss oscillates. If too low, training is slow or stalls.
```python
# Track the learning rate each epoch
for epoch in range(epochs):
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}: LR = {current_lr:.6f}")
    # ... training loop ...
```
Rules of thumb:
- Start with 1e-3 for Adam/AdamW
- Start with 1e-1 for SGD with momentum
- Use learning rate finder if unsure
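To see the failure modes concretely, here is a toy experiment (the quadratic objective and step counts are arbitrary): plain SGD on f(x) = x² converges with a small step size and blows up once the step size crosses the stability limit.

```python
import torch

def run_sgd(lr, steps=50):
    x = torch.tensor([1.0], requires_grad=True)
    opt = torch.optim.SGD([x], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = (x ** 2).sum()
        loss.backward()
        opt.step()
    return x.item()

print(run_sgd(lr=0.1))  # converges toward 0
print(run_sgd(lr=1.2))  # |x| explodes: LR too high
```

The same dynamic plays out (less cleanly) in real training: past a threshold, each step overshoots the minimum by more than the last.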
6. Visualize Training Curves¶
Plot loss and accuracy over epochs to spot patterns.
```python
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(val_accs, label='Val Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.show()
```
Patterns to recognize:

- Flat loss: model isn't learning → check LR, data, or architecture
- Diverging train/val curves: overfitting → add regularization
- Oscillating loss: LR too high → reduce the LR or add gradient clipping
Common Failure Modes¶
Loss is NaN¶
Causes:

- Exploding gradients
- Learning rate too high
- Numerical instability in the loss (e.g., log(0))
Fixes:
- Add gradient clipping: `torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)`
- Lower learning rate by 10x
- Check for division by zero or invalid operations
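One way to catch the problem early (a sketch; the guard would sit inside your own training loop) is to stop as soon as the loss stops being finite, and to enable autograd anomaly detection so the offending operation is reported during backward:

```python
import torch

# Debug-only: slow, but pinpoints the op that produced NaN/inf during backward
torch.autograd.set_detect_anomaly(True)

def step_is_safe(loss):
    # Skip the optimizer step (and investigate) if the loss is NaN or inf
    return bool(torch.isfinite(loss))

print(step_is_safe(torch.tensor(0.693)))          # True
print(step_is_safe(torch.log(torch.tensor(0.))))  # False: log(0) = -inf
```

Remember to turn anomaly detection off again once the bug is found; it adds significant overhead.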
Training Accuracy ≈ Random Chance¶
Causes:

- Wrong loss function for the task
- Labels misaligned with outputs
- Model too shallow for the task
- Data not preprocessed correctly
Fixes:

- Overfit a single batch to rule out architecture issues
- Print and inspect labels vs. predictions for one batch
- Check if data augmentation is too aggressive
Validation Loss Increases While Training Loss Decreases¶
Cause: Overfitting
Fixes:

- Add dropout or weight decay
- Use data augmentation
- Reduce model capacity
- Get more training data
Training is Extremely Slow¶
Causes:
- Inefficient data loading (no num_workers or pin_memory)
- Batch size too small
- Model too large for the task
Fixes:
- Set num_workers=4 and pin_memory=True in DataLoader
- Increase batch size (if GPU memory allows)
- Use mixed precision training (torch.cuda.amp)
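A minimal sketch of the DataLoader settings above, using a synthetic stand-in dataset (the tensor sizes and worker count are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for a real dataset: 256 fake 3x32x32 images, binary labels
dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 2, (256,)))

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,    # load batches in parallel background processes
    pin_memory=True,  # faster host-to-GPU transfers (only matters with CUDA)
)

inputs, labels = next(iter(loader))
print(inputs.shape, labels.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])
```

Tune `num_workers` to your CPU count; too many workers can be as slow as too few.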
Debugging Checklist¶
Before changing architecture or hyperparameters, verify:
- [ ] Model can overfit a single batch
- [ ] Gradients are flowing (not vanishing/exploding)
- [ ] Data shapes and labels are correct
- [ ] Loss function matches the task
- [ ] Learning rate is in a reasonable range
- [ ] Training curves are being monitored
- [ ] train() and eval() modes are set correctly
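The last checklist item matters because layers like dropout and batch norm behave differently per mode. A tiny demonstration (the dropout probability is arbitrary):

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()
print(dropout(x))  # roughly half the entries zeroed, survivors scaled to 2.0

dropout.eval()
print(dropout(x))  # identity: all ones, no randomness
```

Evaluating with the model left in train() mode is a classic source of noisy, pessimistic validation metrics.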
Tools for Deep Inspection¶
1. TensorBoard¶
```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/experiment_1')

# Log scalars
writer.add_scalar('Loss/train', train_loss, epoch)
writer.add_scalar('Accuracy/train', train_acc, epoch)

# Log histograms of weights
for name, param in model.named_parameters():
    writer.add_histogram(name, param, epoch)

writer.close()
```
View with: `tensorboard --logdir=runs`
2. Weights & Biases (wandb)¶
```python
import wandb

wandb.init(project="debugging-demo")
wandb.watch(model, log="all")  # Log gradients and parameters

for epoch in range(epochs):
    # ... training ...
    wandb.log({"loss": loss, "accuracy": acc, "epoch": epoch})
```
3. Print Intermediate Activations¶
```python
def register_hooks(model):
    activations = {}

    def get_activation(name):
        def hook(module, input, output):
            activations[name] = output.detach()
        return hook

    for name, layer in model.named_modules():
        if isinstance(layer, torch.nn.ReLU):
            layer.register_forward_hook(get_activation(name))
    return activations

activations = register_hooks(model)
outputs = model(inputs)

# Check if activations are dead (all zeros)
for name, act in activations.items():
    print(f"{name}: mean = {act.mean():.4f}, std = {act.std():.4f}")
```
Quick Quiz¶
1. What does it mean if a model can't overfit a single batch?
   a) Learning rate is too low
   b) The architecture or loss function is broken
   c) Regularization is too strong
   d) Data augmentation is needed

2. What's the first thing to check if loss becomes NaN?
   a) Add more layers
   b) Check for exploding gradients and reduce the LR
   c) Increase batch size
   d) Switch to a different optimizer

3. If gradients are < 1e-6 in early layers, what's likely happening?
   a) Overfitting
   b) Vanishing gradients
   c) Exploding gradients
   d) Perfect convergence
Practice¶
- Create a deliberately broken model (e.g., wrong loss function) and debug it using the ladder.
- Train a model that overfits—diagnose it by comparing train vs. validation curves.
- Implement gradient norm tracking and identify if/when gradients vanish or explode.
- Use TensorBoard to visualize weight histograms and spot dead neurons.
Next Steps¶
After mastering debugging:

- Read Optimizers and Regularization to fix overfitting
- Read Learning Rate Schedulers to stabilize training
- Try PyTorch Training Loops for clean loop patterns
Summary¶
The debugging ladder:

1. Overfit one batch (sanity check)
2. Check gradient flow
3. Verify data pipeline
4. Check model output shapes
5. Monitor learning rate
6. Visualize training curves
Most common issues:

- Wrong loss function
- Bad learning rate
- Data shape mismatch
- Vanishing/exploding gradients
Golden rule: Fix the data and architecture before tuning hyperparameters.