Classify a Document in 10 Lines

The Document Image Transformer (DiT) from Microsoft treats document pages as images and classifies them into types like invoice, receipt, letter, or form. No OCR required – the model learns visual layout patterns directly from the pixel data.

Here is how to classify a scanned document with a pretrained DiT model:

from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch

model_name = "microsoft/dit-base-finetuned-rvlcdip"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()

image = Image.open("scan_001.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
    predicted_class = outputs.logits.argmax(-1).item()

label = model.config.id2label[predicted_class]
print(f"Document type: {label}")
# Output: "Document type: invoice"

The microsoft/dit-base-finetuned-rvlcdip checkpoint is trained on RVL-CDIP, a dataset of 400,000 grayscale document images across 16 categories: letter, memo, email, file folder, form, handwritten, invoice, advertisement, budget, news article, presentation, scientific publication, questionnaire, resume, scientific report, and specification. It hits around 92% accuracy out of the box.

Why DiT Over OCR-Based Classifiers

Traditional document classification extracts text with OCR, then feeds it to a text classifier. That pipeline is brittle. OCR errors propagate, handwritten documents choke the recognizer, and you need language-specific models for every locale.

DiT skips all of that. It is a vision-only model based on the BEiT architecture. It learns from the visual structure of documents – where headers sit, how tables are laid out, the density of text blocks. This makes it language-agnostic and robust to poor scan quality.

LayoutLMv3 is the other strong option. It combines visual features with text embeddings, so it performs better when text content matters for classification (its processor can run Tesseract OCR for you, or accept words and boxes from your own OCR engine). Use DiT when you want a pure vision approach with no OCR dependency. Use LayoutLMv3 when the text content distinguishes document types that look visually similar.

Preprocessing Document Images

Document scans come in all shapes. Some are skewed, some have borders, some are 300 DPI TIFFs. The processor handles resizing and normalization, but you should clean up the input for best results.

from PIL import Image, ImageOps

def preprocess_document(path: str) -> Image.Image:
    """Clean up a scanned document for classification."""
    img = Image.open(path)

    # Auto-orient based on EXIF data (common with phone scans).
    # Do this before convert(), which can drop the EXIF metadata.
    img = ImageOps.exif_transpose(img)

    # Convert to RGB (handles grayscale, RGBA, CMYK)
    img = img.convert("RGB")

    # Remove black borders from scanning artifacts
    gray = img.convert("L")
    bbox = gray.point(lambda x: 255 if x > 10 else 0).getbbox()
    if bbox:
        img = img.crop(bbox)

    return img

A few tips on input quality:

  • Resolution: DiT was trained on 224x224 images. The processor downscales for you, but feeding in extremely low-resolution scans (under 100 DPI) will lose structural details.
  • Color mode: Always convert to RGB. Grayscale documents classify fine, but the model expects 3 channels, so a single-channel image must be converted first.
  • Multi-page PDFs: Extract pages individually with pdf2image or PyMuPDF and classify each page separately.

Fine-Tune on Custom Document Types

The RVL-CDIP categories probably don’t match your use case. If you need to classify purchase orders vs. packing slips vs. customs declarations, you need to fine-tune.

from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
import torch

model_name = "microsoft/dit-base-finetuned-rvlcdip"
processor = AutoImageProcessor.from_pretrained(model_name)

# Load your labeled dataset (expects "image" and "label" columns)
# Structure: dataset_dir/train/invoice/001.png, dataset_dir/validation/invoice/101.png, etc.
dataset = load_dataset("imagefolder", data_dir="./document_dataset")

# Define labels
labels = dataset["train"].features["label"].names
label2id = {name: i for i, name in enumerate(labels)}
id2label = {i: name for i, name in enumerate(labels)}

# Load model with new classification head
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # Required -- replaces the old head
)

# Preprocessing function
def transform(examples):
    images = [img.convert("RGB") for img in examples["image"]]
    inputs = processor(images=images, return_tensors="pt")
    inputs["labels"] = examples["label"]
    return inputs

dataset = dataset.with_transform(transform)

training_args = TrainingArguments(
    output_dir="./dit-document-classifier",
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    remove_unused_columns=False,
)

import numpy as np
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, predictions)}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.save_model("./dit-document-classifier/best")

A few hundred labeled examples per class is usually enough to get strong accuracy when fine-tuning from the RVL-CDIP checkpoint, especially if your document types are visually similar to the original 16 categories. If your domain is very different (medical forms, architectural plans), expect to need more data. If you have fewer than 50 samples per class, consider augmenting with rotations, slight skew, and brightness jitter.
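
The rotation, skew, and brightness augmentations mentioned above can be done with Pillow alone. A minimal sketch (the ranges are illustrative starting points, not tuned values):

```python
import random

from PIL import Image, ImageEnhance


def augment_document(img: Image.Image) -> Image.Image:
    """Apply a random small rotation, shear, and brightness change to a scan."""
    # Slight rotation, filling the revealed corners with white
    angle = random.uniform(-3.0, 3.0)
    out = img.rotate(angle, expand=True, fillcolor=(255, 255, 255))

    # Slight horizontal skew via an affine transform
    shear = random.uniform(-0.02, 0.02)
    w, h = out.size
    out = out.transform(
        (w, h), Image.AFFINE, (1, shear, 0, 0, 1, 0),
        fillcolor=(255, 255, 255),
    )

    # Brightness jitter
    factor = random.uniform(0.85, 1.15)
    return ImageEnhance.Brightness(out).enhance(factor)
```

Generate a handful of augmented copies per original and include them in the training split; keep the validation set unaugmented so your metrics stay honest.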

Batch Processing a Directory of Scans

For production workloads, you need to process thousands of files. Here is a batch pipeline that classifies every document in a directory and writes results to a CSV:

import csv
from pathlib import Path
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch

model_path = "./dit-document-classifier/best"
processor = AutoImageProcessor.from_pretrained(model_path)
model = AutoModelForImageClassification.from_pretrained(model_path)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

scan_dir = Path("./incoming_scans")
results = []

for img_path in sorted(scan_dir.glob("*.png")):
    image = Image.open(img_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        confidence, predicted = probs.max(-1)

    label = model.config.id2label[predicted.item()]
    results.append({
        "file": img_path.name,
        "label": label,
        "confidence": f"{confidence.item():.4f}",
    })
    print(f"{img_path.name} -> {label} ({confidence.item():.2%})")

# Write results to CSV
with open("classification_results.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["file", "label", "confidence"])
    writer.writeheader()
    writer.writerows(results)
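
With confidences in hand, a common production pattern is to auto-accept only high-confidence predictions and route the rest to manual review. A minimal sketch that works on the results rows above (the 0.80 threshold and the route_by_confidence helper are illustrative, not part of the pipeline):

```python
def route_by_confidence(results, threshold=0.80):
    """Split classification rows into auto-accepted and needs-review buckets."""
    accepted, review = [], []
    for row in results:
        if float(row["confidence"]) >= threshold:
            accepted.append(row)
        else:
            review.append(row)
    return accepted, review
```

Tune the threshold on a held-out set so the auto-accepted bucket meets your target precision.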

For large batches, speed this up by batching multiple images into a single forward pass:

from itertools import islice

def batched(iterable, n):
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch

for batch_paths in batched(sorted(scan_dir.glob("*.png")), 32):
    images = [Image.open(p).convert("RGB") for p in batch_paths]
    inputs = processor(images=images, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        confidences, predictions = probs.max(-1)

    for path, conf, pred in zip(batch_paths, confidences, predictions):
        label = model.config.id2label[pred.item()]
        results.append({"file": path.name, "label": label, "confidence": f"{conf.item():.4f}"})

This runs one forward pass per batch of 32 instead of per image, cutting wall-clock time dramatically on a GPU.

Common Errors and Fixes

RuntimeError: expected scalar type Float but found Byte

You passed raw pixel data without running it through the processor. The processor converts uint8 pixel values to normalized floats. Always use the AutoImageProcessor before feeding images to the model.
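
To see what the processor is doing, here is the normalization step in isolation (mean and std of 0.5 match the BEiT-family defaults, but check processor.image_mean and processor.image_std on your checkpoint):

```python
import numpy as np


def normalize_pixels(pixels: np.ndarray, mean: float = 0.5, std: float = 0.5) -> np.ndarray:
    """uint8 HWC pixels -> normalized float32, as the image processor does internally."""
    scaled = pixels.astype(np.float32) / 255.0  # 0..255 -> 0..1
    return (scaled - mean) / std                # 0..1 -> roughly -1..1


pixels = np.zeros((224, 224, 3), dtype=np.uint8)  # a dummy all-black page
out = normalize_pixels(pixels)
```

Passing the raw uint8 array to the model skips both steps, which is exactly what triggers the dtype error.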

ValueError: ignore_mismatched_sizes not recognized

You are on an old version of transformers. Update with pip install --upgrade transformers. The ignore_mismatched_sizes parameter has been supported in from_pretrained since transformers v4.9.

Model predicts the same class for every input

This usually points at the training setup rather than the model. Common causes: a learning rate high enough to wipe out the pretrained features (try 2e-5 or lower), severe class imbalance in the training data, or labels that do not line up with their images. Check your training logs – if training loss is not decreasing, something in the setup is misconfigured.

PIL.UnidentifiedImageError: cannot identify image file

The file is corrupted or is not actually an image (sometimes PDFs get mixed in). Wrap your loading in a try/except and log the failures:

try:
    image = Image.open(path).convert("RGB")
except Exception as e:
    print(f"Skipping {path}: {e}")

Out of memory on GPU

DiT-base is ~86M parameters and fits comfortably on a 4GB GPU for inference. If you are fine-tuning and hitting OOM, reduce per_device_train_batch_size to 8 or 4. If that is still too much, enable gradient checkpointing by adding gradient_checkpointing=True in your TrainingArguments.
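
Put together, a memory-constrained configuration might look like this (a sketch of the relevant TrainingArguments fields; gradient_accumulation_steps is an addition that keeps the effective batch size at 16 while lowering peak memory):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./dit-document-classifier",
    per_device_train_batch_size=4,    # down from 16
    gradient_accumulation_steps=4,    # 4 x 4 = effective batch size of 16
    gradient_checkpointing=True,      # recompute activations to save memory
    num_train_epochs=10,
    learning_rate=2e-5,
)
```

Gradient checkpointing trades roughly 20-30% extra compute for a substantial cut in activation memory, which is usually a good deal on a small GPU.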

Low accuracy on colored or glossy documents

RVL-CDIP is a grayscale dataset. If your documents are colorful (marketing materials, brochures), the pretrained model may struggle. Fine-tuning on your actual document distribution fixes this quickly.