Mixed Precision Training

Scenario: Training a Large Vision Transformer

You're training a ViT-Large model on ImageNet. Full precision (float32) requires 8 GPUs, but mixed precision lets you fit it on 4 GPUs with faster training—cutting costs and time.

What This Is

Mixed precision training uses both 16-bit and 32-bit floating point during training to reduce memory usage and speed up computation without sacrificing model quality. Modern GPUs have specialized hardware (tensor cores) that run 16-bit operations significantly faster.
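The memory saving is easy to verify directly: a 16-bit tensor stores the same number of elements in half the bytes. A minimal CPU-only check in PyTorch:

```python
import torch

# Same 1024 elements in each dtype; only the bytes per element differ
x32 = torch.zeros(1024, dtype=torch.float32)
x16 = torch.zeros(1024, dtype=torch.float16)

print(x32.element_size())  # 4 bytes per element
print(x16.element_size())  # 2 bytes per element
```

Activations dominate training memory for large models, so halving their size is where most of the savings come from.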

Objectives

  • Understand how mixed precision reduces memory and speeds up training
  • Implement automatic mixed precision (AMP) with GradScaler and autocast
  • Choose between float16 and bfloat16 based on hardware
  • Handle gradient scaling to prevent underflow
  • Scale batch sizes and model sizes within memory constraints

When You Use It

  • training large models that do not fit in GPU memory at full precision
  • speeding up training on GPUs with tensor core support (Volta, Ampere, Hopper)
  • scaling batch size within fixed memory constraints
  • training production models where wall-clock time matters

Tooling

  • torch.amp.GradScaler — scales loss to prevent gradient underflow in float16
  • torch.amp.autocast — automatically casts operations to float16 where safe
  • torch.float16 and torch.bfloat16 — the two common reduced-precision formats

How It Works

Mixed precision keeps the master weights in float32 but runs the forward pass and most of the backward pass in float16. A gradient scaler prevents small gradients from underflowing to zero.
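The underflow the scaler guards against can be reproduced with plain tensors. The 2**16 factor below is just an illustrative value; GradScaler chooses and adapts the scale dynamically:

```python
import torch

# 1e-8 is below float16's smallest subnormal (~6e-8), so it rounds to zero
g = torch.tensor(1e-8)
print(g.to(torch.float16).item())            # 0.0: underflowed

# Scaling first (as scaler.scale(loss) does for the whole backward pass)
# moves the value into float16's representable range
print((g * 2**16).to(torch.float16).item())  # roughly 6.55e-4: preserved
```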

from torch.amp import GradScaler, autocast

scaler = GradScaler()  # maintains a dynamic loss-scale factor

for X_batch, y_batch in train_loader:
    optimizer.zero_grad(set_to_none=True)

    # forward pass runs in float16 wherever autocast deems it safe
    with autocast(device_type="cuda"):
        logits = model(X_batch)
        loss = loss_fn(logits, y_batch)

    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adapts the scale factor for the next step

float16 vs bfloat16

Format     Range             Precision             Best For
float16    narrow            more mantissa bits    older GPUs; needs GradScaler
bfloat16   same as float32   fewer mantissa bits   Ampere and newer; often no scaler needed
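The range and precision claims in the table can be checked with torch.finfo:

```python
import torch

f16 = torch.finfo(torch.float16)
bf16 = torch.finfo(torch.bfloat16)

print(f16.max)   # 65504.0: narrow range, overflows easily
print(bf16.max)  # about 3.4e38: same order of magnitude as float32
print(f16.eps)   # about 9.8e-4: 10 mantissa bits, finer precision
print(bf16.eps)  # about 7.8e-3: 7 mantissa bits, coarser precision
```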

If your GPU supports bfloat16, it is often simpler because the wider range means gradients rarely underflow:

with autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(X_batch)
    loss = loss_fn(logits, y_batch)

loss.backward()   # no GradScaler: bfloat16's wide range makes underflow rare
optimizer.step()

Quick Quiz

  1. What is the main benefit of mixed precision training?
    a) Higher model accuracy
    b) Reduced memory usage and faster computation
    c) Simpler code
    d) Better generalization

  2. When do you need GradScaler?
    a) Always with mixed precision
    b) Only with float16, not bfloat16
    c) Only on older GPUs
    d) Never, autocast handles it

  3. What should you keep in float32 during mixed precision?
    a) All weights
    b) Loss functions and batch norm statistics
    c) Only the optimizer
    d) Nothing, autocast handles everything

Validation Pattern

Validation also benefits from autocast for speed, but it does not need the scaler:

model.eval()
with torch.no_grad():
    with autocast(device_type="cuda"):
        logits = model(X_valid)
        val_loss = loss_fn(logits, y_valid)

What To Keep In float32

Some operations are numerically unstable in float16:

  • loss computations (autocast runs these in float32 automatically)
  • batch normalization running statistics
  • the master weights that receive optimizer updates (updates scaled by a small learning rate can vanish in float16)
  • operations with large reductions (for example, softmax over long sequences)

The autocast context manager handles most of these cases automatically.
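One way to see this per-operation behavior without a GPU is the CPU autocast backend with bfloat16: eligible ops such as matmul run in the low-precision dtype, while the input tensors themselves stay float32:

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = a @ b  # matmul is autocast-eligible, so the result is bfloat16

print(c.dtype)  # torch.bfloat16
print(a.dtype)  # torch.float32: autocast casts per op, inputs are untouched
```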

Failure Pattern

Enabling mixed precision with float16 but no GradScaler: small gradients underflow to zero so training silently stalls, or the loss turns NaN when values overflow float16's range.

Another failure: assuming mixed precision always helps. On CPUs or older GPUs without tensor cores, it may actually be slower.

Common Mistakes

  • forgetting scaler.update() after scaler.step(), which freezes the scale factor
  • using loss.backward() instead of scaler.scale(loss).backward()
  • applying gradient clipping outside the scaler workflow
  • expecting speed gains on hardware without tensor cores

Gradient Clipping With Mixed Precision

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()

Practice

  1. Compare training speed with and without mixed precision on the same model.
  2. Monitor memory usage with torch.cuda.max_memory_allocated() in both modes.
  3. Add gradient clipping to a mixed precision loop and verify it works correctly.
  4. Switch between float16 and bfloat16 and compare stability.
  5. Explain why the GradScaler is necessary for float16 but often unnecessary for bfloat16.

Checkpoint

  • [ ] Implement autocast and GradScaler in a training loop
  • [ ] Compare memory usage with and without mixed precision
  • [ ] Handle gradient clipping correctly with the scaler
  • [ ] Choose between float16 and bfloat16 for your hardware
  • [ ] Monitor for NaN losses and adjust scaling if needed

Runnable Example

This example stays local-only for now because the browser runner does not yet include PyTorch.

Longer Connection

Continue with PyTorch Training Loops for the full loop structure, and Optimizers and Regularization for the optimizer patterns that interact with mixed precision.

Further Reading