
Vision Augmentation and Shift Robustness

What This Is

Small image shifts can expose whether a model learned a stable visual pattern or only a brittle pixel layout. Augmentation is one practical way to teach the model the invariance you want.

The useful distinction is between "more data" and "better invariance." Augmentation is not just a quantity trick; it is a way to encode the type of shift you expect at test time.

For applied work, the key question is simple: does your model still work when the input moves a little?

When You Use It

  • checking whether a vision baseline is too brittle
  • adding lightweight augmentation before moving to a heavier model
  • comparing clean accuracy against perturbed accuracy
  • deciding whether the problem is really about translation, framing, or local shape

Start Here

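If you want a fast local test, load_digits is enough. It provides 1,797 grayscale handwritten digits at 8x8 pixels. That suits shift experiments because you can move the digit without changing its class. At that resolution, a one-pixel shift already moves the digit by an eighth of the image width, so even small offsets are a meaningful test.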

Typical first move:

  1. load the images and labels
  2. split with stratify=y
  3. fit a linear baseline on flattened pixels
  4. create shifted test images
  5. measure the clean-versus-shifted gap

That gap tells you more than a single accuracy number.

Tooling

  • load_digits
  • zero-padded shift functions, np.stack, and reshape
  • train_test_split
  • LogisticRegression
  • Pipeline with StandardScaler
  • balanced_accuracy_score, classification_report, confusion_matrix
  • small torch.nn.Conv2d models when a linear view is too brittle
  • torch.utils.data.DataLoader
  • model.train(), model.eval(), torch.no_grad(), torch.inference_mode()
  • torchvision.transforms.Compose, RandomAffine, RandomHorizontalFlip, CenterCrop, Normalize, ToTensor

Minimal Example

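import numpy as np

def shift_image(image, dx=1, dy=1):
    """Shift a 2-D image by (dx, dy) pixels: +dx moves right, +dy moves down; vacated pixels become 0."""
    shifted = np.zeros_like(image)
    # Range of source columns and rows that stay inside the frame after the shift
    x_from = max(0, -dx)
    x_to = image.shape[1] - max(0, dx)
    y_from = max(0, -dy)
    y_to = image.shape[0] - max(0, dy)
    # Copy the surviving block to its new position; everything else stays zero
    shifted[y_from + dy : y_to + dy, x_from + dx : x_to + dx] = image[y_from:y_to, x_from:x_to]
    return shifted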

For load_digits, this zero-padded shift gives a more honest translation test than np.roll. np.roll wraps pixels around to the opposite edge, which can create unrealistic artifacts.
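To see the difference, compare np.roll with the zero-padded helper on a made-up 3x3 array (not a digit). The helper is repeated in compact form so the snippet runs on its own.

```python
import numpy as np


def shift_image(image, dx=1, dy=1):
    # Same zero-padded shift as the helper above, written compactly
    shifted = np.zeros_like(image)
    x_from, x_to = max(0, -dx), image.shape[1] - max(0, dx)
    y_from, y_to = max(0, -dy), image.shape[0] - max(0, dy)
    shifted[y_from + dy : y_to + dy, x_from + dx : x_to + dx] = image[y_from:y_to, x_from:x_to]
    return shifted


image = np.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

rolled = np.roll(image, shift=1, axis=1)  # right shift, last column wraps to the front
padded = shift_image(image, dx=1, dy=0)   # right shift, vacated column filled with 0

print(rolled)
# [[3 1 2]
#  [6 4 5]
#  [9 7 8]]
print(padded)
# [[0 1 2]
#  [0 4 5]
#  [0 7 8]]
```

With np.roll, the right edge of the digit reappears on the left. That pattern never occurs in real handwriting, so it measures something other than translation robustness.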

Clean Baseline

from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

digits = load_digits()
x = digits.images
y = digits.target

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=0, stratify=y
)

linear = make_pipeline(
    StandardScaler(),
    LogisticRegression(max_iter=2000)
)
linear.fit(x_train.reshape(len(x_train), -1), y_train)

Why this matters:

  • flattening turns the image into a fast pixel baseline
  • StandardScaler gives each pixel feature zero mean and unit variance, which usually helps the linear solver converge
  • stratify=y keeps the class proportions the same in the train and test splits

Shifted Evaluation

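import numpy as np
from sklearn.metrics import balanced_accuracy_score

# Move every test image one pixel right and one pixel down (zero-padded)
shifted_test = np.stack([shift_image(img, dx=1, dy=1) for img in x_test])

clean_pred = linear.predict(x_test.reshape(len(x_test), -1))
shift_pred = linear.predict(shifted_test.reshape(len(shifted_test), -1))

# Balanced accuracy averages per-class recall, so no single digit dominates the score
clean_score = balanced_accuracy_score(y_test, clean_pred)
shift_score = balanced_accuracy_score(y_test, shift_pred)
robustness_gap = clean_score - shift_score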

What to look for:

  • a small gap means the model is fairly stable
  • a large gap means the representation is brittle
  • a good clean score alone is not enough if the shifted score collapses

If you want a more detailed read, add a confusion matrix and a class report. That often shows whether the shift mostly hurts a few digits or breaks the whole view.
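Here is a minimal sketch of that detailed read. It assumes the same split and scaled logistic-regression pipeline as the baseline above and the (1, 1) shift. The setup is repeated so it runs standalone. The per-class recall ranking at the end is an illustrative addition, not part of the original recipe.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def shift_image(image, dx=1, dy=1):
    # Zero-padded shift, same logic as the helper above
    shifted = np.zeros_like(image)
    x_from, x_to = max(0, -dx), image.shape[1] - max(0, dx)
    y_from, y_to = max(0, -dy), image.shape[0] - max(0, dy)
    shifted[y_from + dy : y_to + dy, x_from + dx : x_to + dx] = image[y_from:y_to, x_from:x_to]
    return shifted


digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.images, digits.target, test_size=0.2, random_state=0, stratify=digits.target
)
linear = make_pipeline(StandardScaler(), LogisticRegression(max_iter=2000))
linear.fit(x_train.reshape(len(x_train), -1), y_train)

shifted_test = np.stack([shift_image(img, dx=1, dy=1) for img in x_test])
shift_pred = linear.predict(shifted_test.reshape(len(shifted_test), -1))

# Rows are true digits, columns are predicted digits
cm = confusion_matrix(y_test, shift_pred, labels=np.arange(10))
print(cm)
print(classification_report(y_test, shift_pred, digits=3))

# Per-class recall under the shift: correct predictions divided by the true count
per_class_recall = cm.diagonal() / cm.sum(axis=1)
hardest = np.argsort(per_class_recall)[:3]
print("digits hurt most by the shift:", hardest, per_class_recall[hardest].round(3))
```

If a few rows account for most of the off-diagonal mass, the shift is breaking specific digits. If the errors are spread evenly, the whole pixel view is brittle.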

Robustness Evaluation Grid

Do not trust a single perturbation direction. Report at least:

  • clean score
  • mean score across a small grid of plausible shifts
  • worst score on that grid

That turns “the model survived one translation” into a more honest robustness statement.

A good small grid is:

  • (-1, 0)
  • (1, 0)
  • (0, -1)
  • (0, 1)
  • (1, 1)

If the worst-case score collapses while the mean looks fine, the augmentation recipe is still too narrow.
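One way to report all three numbers is sketched below. `robustness_report` and `SHIFT_GRID` are hypothetical names. The grid is the one listed above, and the model is the same scaled logistic-regression baseline.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

SHIFT_GRID = [(-1, 0), (1, 0), (0, -1), (0, 1), (1, 1)]  # hypothetical name for the grid above


def shift_image(image, dx=1, dy=1):
    # Zero-padded shift, same logic as the helper above
    shifted = np.zeros_like(image)
    x_from, x_to = max(0, -dx), image.shape[1] - max(0, dx)
    y_from, y_to = max(0, -dy), image.shape[0] - max(0, dy)
    shifted[y_from + dy : y_to + dy, x_from + dx : x_to + dx] = image[y_from:y_to, x_from:x_to]
    return shifted


def robustness_report(model, images, labels, grid=SHIFT_GRID):
    """Hypothetical helper: clean, mean-over-grid and worst-over-grid balanced accuracy."""

    def score(batch):
        return balanced_accuracy_score(labels, model.predict(batch.reshape(len(batch), -1)))

    per_shift = {
        (dx, dy): score(np.stack([shift_image(img, dx, dy) for img in images]))
        for dx, dy in grid
    }
    worst_shift = min(per_shift, key=per_shift.get)
    return {
        "clean": score(images),
        "mean": float(np.mean(list(per_shift.values()))),
        "worst": per_shift[worst_shift],
        "worst_shift": worst_shift,
        "per_shift": per_shift,
    }


digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.images, digits.target, test_size=0.2, random_state=0, stratify=digits.target
)
linear = make_pipeline(StandardScaler(), LogisticRegression(max_iter=2000))
linear.fit(x_train.reshape(len(x_train), -1), y_train)

report = robustness_report(linear, x_test, y_test)
print(f"clean={report['clean']:.3f} mean={report['mean']:.3f} "
      f"worst={report['worst']:.3f} at shift {report['worst_shift']}")
```

Reporting `worst_shift` next to the mean shows which direction the training recipe still misses.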

Augmentation Recipes

For real image tasks, augmentation should match the shift you expect.

Useful torchvision patterns:

  • Compose to keep the transforms in one place
  • RandomAffine for small translation, rotation, or scale changes
  • CenterCrop when the framing varies
  • ToTensor before Normalize, since Normalize operates on tensors (in torchvision.transforms.v2, ToTensor is deprecated in favor of ToImage followed by ToDtype(torch.float32, scale=True))

Practical rule:

  • use RandomAffine when a small location change should not change the label
  • use RandomHorizontalFlip only when left-right mirroring is label-safe
  • do not use flip augmentation for tasks where orientation changes the class

That last point matters for digits. Flipping a digit can turn it into a different class or an invalid example, so the transform must match the task, not the default recipe.

Worked Pattern

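from sklearn.base import clone

shifted_train = np.stack([shift_image(img, dx=1, dy=1) for img in x_train])
augmented_train = np.concatenate([x_train, shifted_train], axis=0)
augmented_labels = np.concatenate([y_train, y_train], axis=0)
# Fit a fresh clone so the clean baseline `linear` is not overwritten
augmented_linear = clone(linear)
augmented_linear.fit(augmented_train.reshape(len(augmented_train), -1), augmented_labels)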

What this pattern gives you:

  • a cleaner way to test invariance without changing the target
  • a quick baseline for seeing whether the shift is the real failure mode
  • a controlled comparison between clean and shifted data
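The worked pattern uses a single offset. The sketch below extends it to the whole evaluation grid. It augments the training split with every offset, fits a clone of the clean recipe, and compares gaps on one shift seen in training and one held out. The helper names `shift_all`, `flat`, and `gap` are hypothetical.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def shift_image(image, dx=1, dy=1):
    # Zero-padded shift, same logic as the helper above
    shifted = np.zeros_like(image)
    x_from, x_to = max(0, -dx), image.shape[1] - max(0, dx)
    y_from, y_to = max(0, -dy), image.shape[0] - max(0, dy)
    shifted[y_from + dy : y_to + dy, x_from + dx : x_to + dx] = image[y_from:y_to, x_from:x_to]
    return shifted


def shift_all(images, dx, dy):
    return np.stack([shift_image(img, dx, dy) for img in images])


def flat(images):
    return images.reshape(len(images), -1)


digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.images, digits.target, test_size=0.2, random_state=0, stratify=digits.target
)

# Augment the TRAINING split only: clean images plus one shifted copy per offset
train_shifts = [(-1, 0), (1, 0), (0, -1), (0, 1), (1, 1)]
aug_x = np.concatenate([x_train] + [shift_all(x_train, dx, dy) for dx, dy in train_shifts])
aug_y = np.concatenate([y_train] * (1 + len(train_shifts)))

baseline = make_pipeline(StandardScaler(), LogisticRegression(max_iter=2000))
augmented = clone(baseline)  # same recipe, fitted separately
baseline.fit(flat(x_train), y_train)
augmented.fit(flat(aug_x), aug_y)


def gap(model, dx, dy):
    clean = balanced_accuracy_score(y_test, model.predict(flat(x_test)))
    shifted = balanced_accuracy_score(y_test, model.predict(flat(shift_all(x_test, dx, dy))))
    return clean - shifted


for name, model in [("baseline", baseline), ("augmented", augmented)]:
    # (1, 1) appears in the training shifts; (-1, -1) does not
    print(f"{name}: gap on (1, 1) = {gap(model, 1, 1):.3f}, "
          f"gap on held-out (-1, -1) = {gap(model, -1, -1):.3f}")
```

The held-out (-1, -1) gap is the more honest of the two numbers. A model can close the gap on an offset it was trained on without generalizing to one it never saw.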

Linear Baseline Vs CNN

Start with the linear model if:

  • the task is small
  • you need a fast sanity check
  • you want to know whether the feature view is the real bottleneck

Move to a small CNN if:

  • the linear baseline stays brittle under small shifts
  • local shape matters more than global pixel position
  • you want the model to learn translation-tolerant features directly

Architecture x Augmentation Checks

The useful ablation is not only “augmentation on or off.” It is the small grid:

  • linear model without augmentation
  • linear model with augmentation
  • CNN without augmentation
  • CNN with augmentation

That tells you whether the gain comes from the representation, the augmentation, or the interaction between them.

A small CNN is often enough:

import torch
from torch import nn

class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        return self.net(x)

The important shape rule is simple:

  • linear models want flattened arrays
  • Conv2d wants (batch, channels, height, width)
  • for grayscale digits, add one channel dimension before the CNN
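The shape rule in code is shown below. It assumes pixels are rescaled from 0..16 to [0, 1], which is an illustrative choice. Flatten for the linear model and use `unsqueeze(1)` for the CNN.

```python
import numpy as np
import torch
from sklearn.datasets import load_digits
from torch import nn


class SmallCNN(nn.Module):
    # Same architecture as above
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(32, 10),
        )

    def forward(self, x):
        return self.net(x)


images = (load_digits().images / 16.0).astype(np.float32)  # (1797, 8, 8), scaled to [0, 1]

flat = images.reshape(len(images), -1)              # (1797, 64): one row per image for the linear model
batch = torch.from_numpy(images[:32]).unsqueeze(1)  # (32, 1, 8, 8): add the single grayscale channel

model = SmallCNN()
with torch.no_grad():
    logits = model(batch)                           # (32, 10): one score per digit class
print(flat.shape, tuple(batch.shape), tuple(logits.shape))
# (1797, 64) (32, 1, 8, 8) (32, 10)
```

Without `unsqueeze(1)`, Conv2d would read a (32, 8, 8) tensor as one unbatched image with 32 channels and reject it, because the first layer expects a single input channel.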

Training And Validation

When you train the CNN, keep the mode switches explicit:

  • model.train() during training
  • model.eval() during validation
  • torch.no_grad() or torch.inference_mode() when scoring

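That matters because dropout and batch norm layers behave differently in train and eval mode. The SmallCNN above has neither, so for it the switch changes nothing. Still, a robustness claim is weak if a model that does have those layers is accidentally evaluated in the wrong mode. Note that torch.no_grad() and torch.inference_mode() only turn off gradient tracking; they do not switch the mode.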
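The sketch below is a compact training loop with the mode switches in place. It assumes Adam with lr=1e-2, batch size 64, and 20 epochs, which are illustrative settings and not tuned. It trains the CNN with and without a random one-pixel shift on the training side, which covers the CNN half of the architecture x augmentation grid. `to_tensor`, `random_shift`, `evaluate`, and `train_model` are hypothetical helper names.

```python
import numpy as np
import torch
from sklearn.datasets import load_digits
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class SmallCNN(nn.Module):
    # Same architecture as above
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(32, 10),
        )

    def forward(self, x):
        return self.net(x)


def shift_image(image, dx=1, dy=1):
    # Zero-padded numpy shift, same logic as the helper above
    shifted = np.zeros_like(image)
    x_from, x_to = max(0, -dx), image.shape[1] - max(0, dx)
    y_from, y_to = max(0, -dy), image.shape[0] - max(0, dy)
    shifted[y_from + dy : y_to + dy, x_from + dx : x_to + dx] = image[y_from:y_to, x_from:x_to]
    return shifted


def to_tensor(images):
    # (N, 8, 8) pixels in 0..16 -> (N, 1, 8, 8) float32 in [0, 1]
    return torch.from_numpy((images / 16.0).astype(np.float32)).unsqueeze(1)


def random_shift(xb, max_shift=1):
    # Zero-pad, then crop back at a random offset: a zero-filled shift of up to
    # max_shift pixels per axis (one offset shared by the whole batch, for brevity)
    h, w = xb.shape[-2:]
    padded = nn.functional.pad(xb, (max_shift,) * 4)
    top, left = torch.randint(0, 2 * max_shift + 1, (2,)).tolist()
    return padded[..., top : top + h, left : left + w]


def evaluate(model, images, labels):
    model.eval()  # eval mode: matters for dropout / batch norm layers
    with torch.inference_mode():  # no gradient tracking while scoring
        preds = model(to_tensor(images)).argmax(dim=1).numpy()
    return balanced_accuracy_score(labels, preds)


def train_model(loader, augment, epochs=20):
    torch.manual_seed(0)
    model = SmallCNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        model.train()  # back to train mode at the start of every epoch
        for xb, yb in loader:
            if augment:
                xb = random_shift(xb)  # augmentation on the training side only
            optimizer.zero_grad()
            loss_fn(model(xb), yb).backward()
            optimizer.step()
    return model


digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.images, digits.target, test_size=0.2, random_state=0, stratify=digits.target
)
loader = DataLoader(
    TensorDataset(to_tensor(x_train), torch.from_numpy(y_train).long()),
    batch_size=64, shuffle=True,
)
shifted_test = np.stack([shift_image(img, dx=1, dy=1) for img in x_test])

results = {}
for augment in (False, True):
    model = train_model(loader, augment)
    clean, shifted = evaluate(model, x_test, y_test), evaluate(model, shifted_test, y_test)
    results[augment] = (clean, shifted)
    print(f"CNN augment={augment}: clean={clean:.3f} shifted={shifted:.3f} gap={clean - shifted:.3f}")
```

The final AdaptiveAvgPool2d averages over the whole feature map. That tends to make this CNN more tolerant of small shifts even without augmentation, so comparing the two runs shows how much the augmentation adds on top of the architecture.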

Failure Pattern

The most common failure is reporting only the clean test score, then discovering too late that one-pixel shifts wipe out the result.

Another failure is augmenting the validation data and then using that score to claim robustness. The model should learn from augmentation, not be graded on it.

Another trap is using a flip augmentation because it looks standard. If the label changes under the flip, the augmentation is wrong even if the code runs.

Practice

  1. Fit a clean baseline and compare it against a shifted test set.
  2. Add simple shift augmentation and measure how much the robustness gap changes.
  3. Explain why augmentation belongs on the training side only.
  4. Compare a linear model and a small CNN before moving to a heavier architecture.
  5. Describe what kind of shift your augmentation actually teaches.
  6. State one case where augmentation would not help because the issue is not shift sensitivity.
  7. Check whether the worst errors are concentrated in a few confusing classes or spread across all classes.
  8. Decide which shift direction matters most before you start tuning the model.

Runnable Example

Open the matching example in AI Academy and run it from the platform.


Inspect the clean score, the shifted score, and the gap before deciding whether the problem is the model, the representation, or the augmentation recipe.

Library Notes

  • load_digits is a compact local dataset that is enough to test shift robustness quickly.
  • a zero-padded shift helper is a simple way to create a controlled translation test without np.roll's wrap-around artifacts.
  • reshape lets you move between image tensors and flattened linear features.
  • LogisticRegression is a fast first check for whether the representation is the bottleneck.
  • Pipeline with StandardScaler keeps preprocessing and the model together.
  • RandomAffine is a direct way to train for small translation or rotation changes.
  • Compose keeps your torchvision transforms explicit and readable.
  • Conv2d is the first deep model worth trying when local visual patterns matter.
  • DataLoader helps when you move from a toy baseline to real training loops.

Questions To Ask

  1. Is the model failing because of a shift or because the class is hard?
  2. Did augmentation close the gap on the shift you actually care about?
  3. Would a larger model help more than a better augmentation recipe?
  4. Is the robustness gain stable across more than one shift direction?
  5. What is the smallest invariance you need for this task?
  6. Does the linear baseline already solve the clean problem well enough to make a CNN unnecessary?
  7. If the CNN helps, is the gain coming from translation tolerance or just extra capacity?
  8. Are you improving the metric you will actually be judged on, or only the easy one?

Longer Connection

Continue with Vision and Audio Workflows for the full clean-versus-shifted comparison across logistic and CNN recipes.