Why Vision Transformers

Vision Transformers (ViT) split images into patches, flatten them into sequences, and process them with the same transformer architecture that powers LLMs. No convolutions, no pooling layers – just attention. Google introduced ViT in 2020 and it matched or beat CNNs on ImageNet when pretrained on enough data.
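To make the patch idea concrete, here is a toy shape-only sketch (plain NumPy, no model) of how a 224x224 RGB image becomes a sequence of 196 flattened 16x16 patches:

```python
import numpy as np

# Toy image: 224x224 RGB (the values don't matter for this shape demo)
image = np.zeros((224, 224, 3), dtype=np.float32)

patch = 16
n = 224 // patch  # 14 patches per side

# Split into non-overlapping 16x16 patches, then flatten each one
patches = (
    image.reshape(n, patch, n, patch, 3)
         .transpose(0, 2, 1, 3, 4)   # group the two patch-grid axes together
         .reshape(n * n, patch * patch * 3)
)

print(patches.shape)  # (196, 768): 196 tokens, each a 768-dim vector
```

Each flattened patch is then linearly projected to the model's hidden size and a [CLS] token is prepended; attention runs over that sequence.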

For most image classification tasks today, ViT is the default choice. Given large-scale pretraining it transfers better than comparable ResNets across diverse datasets, and Hugging Face makes it trivial to load pretrained checkpoints.

Here is the fastest way to classify an image:

from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch

# Load pretrained ViT (trained on ImageNet-21k, fine-tuned on ImageNet-1k)
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
model.eval()

# Load and preprocess an image
image = Image.open("dog.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Run inference
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = logits.argmax(-1).item()

print(model.config.id2label[predicted_class])
# Output: "golden retriever"

Three imports, a few lines, and you have a working image classifier. The ViTImageProcessor handles resizing to 224x224, normalization, and conversion to tensors – you don’t need to write any of that yourself.
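If you want a confidence score rather than just the top class, run the logits through a softmax. A toy sketch with made-up three-class logits (no model needed):

```python
import numpy as np

def softmax(x):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 0.5, -1.0])  # hypothetical logits for 3 classes
probs = softmax(logits)
print(probs.argmax(), round(float(probs.max()), 2))  # 0 0.79
```

In the real pipeline the equivalent is torch.softmax(outputs.logits, dim=-1) before the argmax.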

How Preprocessing Works

ViT expects a very specific input format. The processor resizes images to the model's expected resolution, rescales pixel values to [0, 1], and normalizes them with the statistics stored in the checkpoint's processor config (mean 0.5, std 0.5 per channel for this model). Skip any of this and you get garbage predictions.

from transformers import ViTImageProcessor
from PIL import Image
import numpy as np

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

image = Image.open("cat.jpg").convert("RGB")

# See exactly what the processor does
inputs = processor(images=image, return_tensors="pt")

print(f"Input shape: {inputs['pixel_values'].shape}")
# Output: torch.Size([1, 3, 224, 224])

print(f"Pixel range: [{inputs['pixel_values'].min():.2f}, {inputs['pixel_values'].max():.2f}]")
# Output: Pixel range: [-1.00, 1.00] (normalized to [-1, 1], not 0-255)

# The processor applies these transforms:
# 1. Resize to 224x224
# 2. Rescale to [0, 1], then normalize with mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
# 3. Convert to a float32 tensor
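The arithmetic is easy to check by hand. This checkpoint's processor config rescales pixels by 1/255 and normalizes with mean and std of 0.5 per channel, which maps the full 0-255 range onto [-1, 1]:

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])  # darkest, middle, brightest
scaled = pixels / 255.0                 # rescale_factor = 1/255
normalized = (scaled - 0.5) / 0.5       # image_mean = image_std = 0.5
print(normalized)  # [-1.  0.  1.]
```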

If you are building a custom pipeline or using a dataloader, you can replicate this with torchvision.transforms:

from torchvision import transforms

# These values mirror this checkpoint's processor config
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # scales pixel values to [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

Use the Hugging Face processor when possible. It stays in sync with the model checkpoint and handles edge cases like different image modes.

Batch Inference

Classifying one image at a time is slow. Batch your inputs to take advantage of GPU parallelism.

from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
from pathlib import Path

model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Load multiple images (sorted for deterministic order)
image_paths = sorted(Path("./photos").glob("*.jpg"))
images = [Image.open(p).convert("RGB") for p in image_paths]

# Process as a batch
inputs = processor(images=images, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)
    predictions = outputs.logits.argmax(-1)

for path, pred in zip(image_paths, predictions):
    label = model.config.id2label[pred.item()]
    print(f"{path.name}: {label}")

Picking the Right ViT Variant

Not all ViTs are equal. The model name tells you everything: vit-base-patch16-224 means base size, 16x16 patches, 224px input.

Model                   Params  Top-1 Acc  Speed     Best For
vit-base-patch16-224    86M     81.1%      Fast      General use, fine-tuning
vit-large-patch16-224   304M    82.6%      Moderate  Higher accuracy needs
vit-base-patch32-224    88M     78.7%      Fastest   Low-latency apps
vit-large-patch16-512   304M    83.8%      Slow      High-res inputs

Go with vit-base-patch16-224 unless you have a specific reason not to. It strikes the best balance between accuracy and speed. The patch32 variants are faster but noticeably less accurate – the larger patches lose fine-grained detail.
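The speed differences fall out of the token count: self-attention cost grows with the square of the sequence length, and the patch size determines that length. A quick back-of-the-envelope check:

```python
def num_patches(image_size, patch_size):
    """Patch tokens per image (the [CLS] token adds one more)."""
    return (image_size // patch_size) ** 2

print(num_patches(224, 16))  # 196 tokens for vit-base-patch16-224
print(num_patches(224, 32))  # 49 tokens -- 4x fewer, hence fastest
print(num_patches(512, 16))  # 1024 tokens -- why the 512px model is slow
```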

Fine-Tuning on a Custom Dataset

The pretrained model knows ImageNet’s 1000 classes. Your dataset probably has different classes. Fine-tuning replaces the classification head and trains on your data while keeping most of the learned features.

from transformers import ViTForImageClassification, ViTImageProcessor, TrainingArguments, Trainer
from datasets import load_dataset
import torch
import numpy as np
from sklearn.metrics import accuracy_score

# Load your dataset (using a Hugging Face dataset as example)
dataset = load_dataset("beans")  # 3 classes: angular_leaf_spot, bean_rust, healthy
train_ds = dataset["train"]
val_ds = dataset["validation"]

# Processor for the base model
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Define label mapping
labels = train_ds.features["labels"].names
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

# Load model with a new classification head
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # replaces the pretrained head
)

# Preprocessing function
def preprocess(batch):
    inputs = processor(images=batch["image"], return_tensors="pt")
    inputs["labels"] = batch["labels"]
    return inputs

# Apply preprocessing
train_ds = train_ds.with_transform(preprocess)
val_ds = val_ds.with_transform(preprocess)

# Metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, preds)}

# Training config
training_args = TrainingArguments(
    output_dir="./vit-beans",
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_steps=10,
    remove_unused_columns=False,
    fp16=torch.cuda.is_available(),
)

# Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

trainer.train()

A few things worth noting here. The ignore_mismatched_sizes=True flag is critical – without it, PyTorch throws an error because the pretrained head has 1000 outputs and your model needs a different number. Set the learning rate low (2e-5) so you don’t destroy the pretrained features. Five epochs is usually enough for small datasets; watch the validation accuracy and stop early if it plateaus.
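If your dataset is tiny and even 2e-5 overfits, one option (not used in the script above) is to freeze the backbone and train only the new head. The pattern depends only on parameter names, shown here on a toy stand-in model; in ViTForImageClassification the backbone parameters are named "vit.*" and the head "classifier.*":

```python
import torch.nn as nn

# Toy stand-in: "vit" plays the backbone, "classifier" the new head
model = nn.ModuleDict({
    "vit": nn.Linear(768, 768),
    "classifier": nn.Linear(768, 3),
})

# Freeze everything except the classification head
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("classifier")

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['classifier.weight', 'classifier.bias']
```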

Evaluating Your Model

After fine-tuning, check how well the model actually performs. Don’t just look at accuracy – examine per-class metrics to find weak spots.

from sklearn.metrics import classification_report
import torch
import numpy as np

model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

all_preds = []
all_labels = []

val_ds_raw = load_dataset("beans")["validation"]

for example in val_ds_raw:
    inputs = processor(images=example["image"], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        pred = outputs.logits.argmax(-1).item()
    all_preds.append(pred)
    all_labels.append(example["labels"])

print(classification_report(all_labels, all_preds, target_names=labels))
# Output:
#                       precision    recall  f1-score   support
#   angular_leaf_spot       0.95      0.93      0.94        44
#           bean_rust       0.94      0.96      0.95        44
#             healthy       0.98      0.98      0.98        44
#            accuracy                           0.95       132

If a class has low recall, you probably need more training examples for that class or harder augmentations to force the model to learn its distinguishing features.

Common Errors

RuntimeError: The size of tensor a (X) must match the size of tensor b (1000) – You loaded a pretrained model without setting num_labels and ignore_mismatched_sizes=True. The classification head still expects 1000 ImageNet classes. Pass both arguments to from_pretrained().

ValueError: Could not find image processor class in the image processor config – You are using an old version of transformers. Update with pip install --upgrade transformers. Versions before 4.30 used ViTFeatureExtractor instead of ViTImageProcessor.

Validation accuracy stuck at random chance. Two likely culprits: your learning rate is too high (try 1e-5 instead of 2e-5), or your preprocessing is wrong. Make sure you are using the same processor for training and evaluation. Mismatched normalization kills accuracy silently.

CUDA out of memory during training. Drop the batch size to 8 or 4. If that still fails, enable gradient checkpointing with model.gradient_checkpointing_enable() – it trades compute for memory.

Images load as RGBA or grayscale. Always call .convert("RGB") on PIL images before passing them to the processor. ViT expects 3-channel RGB input. A 4-channel PNG or single-channel grayscale image will cause shape mismatches.
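You can verify the conversion with in-memory toy images, no files needed:

```python
from PIL import Image

rgba = Image.new("RGBA", (8, 8), (255, 0, 0, 128))  # 4-channel, like many PNGs
gray = Image.new("L", (8, 8), 100)                  # single-channel grayscale

for img in (rgba, gray):
    rgb = img.convert("RGB")
    print(rgb.mode, len(rgb.getbands()))  # RGB 3
```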

When ViT Is Not the Right Pick

ViT needs a lot of data to train from scratch. If you have fewer than 1000 images and cannot use a pretrained checkpoint, a ResNet or EfficientNet with transfer learning might converge faster. ViT also struggles with very high resolution images where patch count explodes – for those cases, look at Swin Transformer which uses hierarchical windows.

For edge deployment where latency matters more than accuracy, MobileViT or EfficientFormer give you transformer-style accuracy in a smaller package.

But for the vast majority of image classification projects with a GPU and a few hundred labeled examples? Fine-tune vit-base-patch16-224 and move on. It is the safe, proven default.