Mixed Precision Training¶
Scenario: Training a Large Vision Transformer¶
You're training a ViT-Large model on ImageNet. Full precision (float32) requires 8 GPUs, but mixed precision lets you fit the model on 4 GPUs and train faster, cutting both cost and time.
What This Is¶
Mixed precision training uses both 16-bit and 32-bit floating point during training to reduce memory usage and speed up computation without sacrificing model quality. Modern GPUs have specialized hardware (tensor cores) that run 16-bit operations significantly faster.
Objectives¶
- Understand how mixed precision reduces memory and speeds up training
- Implement automatic mixed precision (AMP) with GradScaler and autocast
- Choose between float16 and bfloat16 based on hardware
- Handle gradient scaling to prevent underflow
- Scale batch sizes and model sizes within memory constraints
When You Use It¶
- training large models that do not fit in GPU memory at full precision
- speeding up training on GPUs with tensor core support (Volta, Ampere, Hopper)
- scaling batch size within fixed memory constraints
- training production models where wall-clock time matters
Tooling¶
- `torch.amp.GradScaler` — scales the loss to prevent gradient underflow in float16
- `torch.amp.autocast` — automatically casts operations to float16 where safe
- `torch.float16` and `torch.bfloat16` — the two common reduced-precision formats
How It Works¶
Mixed precision keeps the master weights in float32 but runs the forward pass and most of the backward pass in float16. A gradient scaler prevents small gradients from underflowing to zero.
```python
from torch.amp import GradScaler, autocast

scaler = GradScaler()

for X_batch, y_batch in train_loader:
    optimizer.zero_grad(set_to_none=True)
    with autocast(device_type="cuda"):
        logits = model(X_batch)
        loss = loss_fn(logits, y_batch)
    scaler.scale(loss).backward()  # scale the loss, then backprop
    scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
    scaler.update()                # adjust the scale factor for the next iteration
```
float16 vs bfloat16¶
| Format | Range | Precision | Best For |
|---|---|---|---|
| `float16` | narrow | more mantissa bits | older GPUs; needs GradScaler |
| `bfloat16` | same as float32 | fewer mantissa bits | Ampere and newer; often no scaler needed |
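The range difference in the table can be checked directly with `torch.finfo`; this quick sketch runs on any machine with PyTorch installed:

```python
import torch

# bfloat16 keeps float32's exponent range but has fewer mantissa bits;
# float16 has more mantissa bits but a maximum value of just 65504.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} max={info.max:.2e}  smallest normal={info.tiny:.2e}")
```

The tiny maximum of float16 is the underflow problem in reverse: gradients far below float16's smallest normal value (about 6e-5) round to zero, which is exactly what the GradScaler compensates for.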
If your GPU supports bfloat16, it is often simpler because the wider range means gradients rarely underflow:
```python
with autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(X_batch)
    loss = loss_fn(logits, y_batch)
loss.backward()   # no GradScaler needed: bfloat16 rarely underflows
optimizer.step()
```
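To choose the dtype at runtime rather than hard-coding it, you can probe the hardware. This is a sketch of one reasonable policy, not the only one:

```python
import torch

# torch.cuda.is_bf16_supported() is True on Ampere and newer GPUs;
# older CUDA GPUs fall back to float16 paired with a GradScaler.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16  # wide range, usually no scaler needed
elif torch.cuda.is_available():
    amp_dtype = torch.float16   # pair with GradScaler
else:
    amp_dtype = None            # no CUDA: skip GPU mixed precision
print(amp_dtype)
```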
Quick Quiz¶
1. What is the main benefit of mixed precision training?
   a) Higher model accuracy
   b) Reduced memory usage and faster computation
   c) Simpler code
   d) Better generalization

2. When do you need GradScaler?
   a) Always with mixed precision
   b) Only with float16, not bfloat16
   c) Only on older GPUs
   d) Never, autocast handles it

3. What should you keep in float32 during mixed precision?
   a) All weights
   b) Loss functions and batch norm statistics
   c) Only the optimizer
   d) Nothing, autocast handles everything
Validation Pattern¶
Validation also benefits from autocast for speed, but it does not need the scaler:
```python
model.eval()
with torch.no_grad():
    with autocast(device_type="cuda"):
        logits = model(X_valid)
        val_loss = loss_fn(logits, y_valid)
```
What To Keep In float32¶
Some operations are numerically unstable in float16:
- loss functions (autocast handles this automatically)
- batch normalization running statistics
- small learning rate updates
- operations with large reductions (softmax over long sequences)
The autocast context manager handles most of these cases automatically.
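You can observe the casting policy directly. This sketch uses CPU autocast with bfloat16, which mirrors the CUDA behavior for matrix multiplies, so it runs anywhere:

```python
import torch

x = torch.randn(4, 8)  # master tensors stay float32
w = torch.randn(8, 8)

with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ w          # matmul is autocast to the reduced precision

print(x.dtype)  # torch.float32 -- the inputs are untouched
print(y.dtype)  # torch.bfloat16 -- the op ran in reduced precision
```

Note that autocast only changes the dtype of operations inside the context; the parameters and inputs themselves remain float32 master copies.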
Failure Pattern¶
Enabling mixed precision without GradScaler when using float16, then getting NaN losses because gradients underflowed to zero.
Another failure: assuming mixed precision always helps. On CPUs or older GPUs without tensor cores, it may actually be slower.
Common Mistakes¶
- forgetting `scaler.update()` after `scaler.step()`, which freezes the scale factor
- using `loss.backward()` instead of `scaler.scale(loss).backward()`
- applying gradient clipping outside the scaler workflow
- expecting speed gains on hardware without tensor cores
Gradient Clipping With Mixed Precision¶
```python
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale gradients before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
```
Practice¶
- Compare training speed with and without mixed precision on the same model.
- Monitor memory usage with `torch.cuda.max_memory_allocated()` in both modes.
- Add gradient clipping to a mixed precision loop and verify it works correctly.
- Switch between float16 and bfloat16 and compare stability.
- Explain why the GradScaler is necessary for float16 but often unnecessary for bfloat16.
Checkpoint¶
- [ ] Implement autocast and GradScaler in a training loop
- [ ] Compare memory usage with and without mixed precision
- [ ] Handle gradient clipping correctly with the scaler
- [ ] Choose between float16 and bfloat16 for your hardware
- [ ] Monitor for NaN losses and adjust scaling if needed
Runnable Example¶
Longer Connection¶
Continue with PyTorch Training Loops for the full loop structure, and Optimizers and Regularization for the optimizer patterns that interact with mixed precision.
Further Reading¶
- PyTorch Automatic Mixed Precision Documentation
- "Mixed Precision Training" (Micikevicius et al., 2018)
- NVIDIA Tensor Core Documentation for hardware-specific optimizations