Disclaimer: This guide is for educational purposes only. Do not use any model built from this tutorial for clinical diagnosis, treatment decisions, or any real medical application. Medical AI systems require regulatory approval (FDA, CE marking) and extensive clinical validation before deployment.
Why CheXNet Still Matters
CheXNet, published by Stanford in 2017, showed that a DenseNet-121 fine-tuned on chest X-rays could match radiologist-level performance on pneumonia detection. The architecture is straightforward: take a pretrained DenseNet-121, swap the final classifier for 14 sigmoid outputs (one per pathology), and train with binary cross-entropy loss.
The same pattern applies today. You get a strong feature extractor from ImageNet pretraining, adapt it for multi-label medical classification, and use Grad-CAM to show where the model is looking. Here’s how to build it from scratch.
Setting Up the Model
Start by loading a pretrained DenseNet-121 and replacing its classifier head. The original CheXNet paper used 14 pathology labels from the ChestX-ray14 dataset: Atelectasis, Cardiomegaly, Effusion, Infiltration, Mass, Nodule, Pneumonia, Pneumothorax, Consolidation, Edema, Emphysema, Fibrosis, Pleural Thickening, and Hernia.
```python
import torch
import torch.nn as nn
from torchvision import models

PATHOLOGIES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax",
    "Consolidation", "Edema", "Emphysema", "Fibrosis",
    "Pleural_Thickening", "Hernia"
]


class CheXNet(nn.Module):
    def __init__(self, num_classes=14, pretrained=True):
        super().__init__()
        weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        self.densenet = models.densenet121(weights=weights)

        # DenseNet-121 classifier is a single Linear layer
        # Replace it for multi-label output
        in_features = self.densenet.classifier.in_features  # 1024
        self.densenet.classifier = nn.Sequential(
            nn.Linear(in_features, num_classes),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.densenet(x)


model = CheXNet(num_classes=14, pretrained=True)
print(f"Classifier input features: {model.densenet.features[-1].num_features}")
```
The Sigmoid() activation is key here. Unlike softmax (which forces outputs to sum to 1), sigmoid lets each pathology output an independent probability. A single X-ray can show both pneumonia and pleural effusion simultaneously.
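To see the difference concretely, apply both activations to the same made-up logits for one image. The three values are arbitrary and simply stand in for three pathologies.

```python
import torch

# Hypothetical logits for one image and three pathologies
logits = torch.tensor([[2.0, 1.5, -3.0]])

probs_sigmoid = torch.sigmoid(logits)         # independent per-class probabilities
probs_softmax = torch.softmax(logits, dim=1)  # probabilities that must share a total of 1

print(probs_sigmoid)  # the first two classes are both above 0.5
print(probs_softmax)  # only the first class is above 0.5
print(probs_sigmoid.sum().item(), probs_softmax.sum().item())
```

With sigmoid, two findings can both be likely at once; with softmax, raising one probability necessarily lowers the others.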
Medical images need special handling. Chest X-rays are typically grayscale, but DenseNet-121 expects 3-channel RGB input. You also need to be careful with augmentations – aggressive rotations or color jitter can distort diagnostic features.
```python
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os

train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3-channel
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet stats work well here
        std=[0.229, 0.224, 0.225]
    ),
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])


class ChestXrayDataset(Dataset):
    """Expects a CSV with columns: image_path, and one column per pathology (0 or 1)."""

    def __init__(self, csv_path, image_dir, transform=None):
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row["image_path"])
        image = Image.open(img_path).convert("L")  # Load as grayscale
        if self.transform:
            image = self.transform(image)
        labels = torch.tensor(
            row[PATHOLOGIES].values.astype(float), dtype=torch.float32
        )
        return image, labels
```
A few notes on these transforms. Grayscale(num_output_channels=3) copies the single channel into all three, which is the standard way to feed grayscale images into RGB-pretrained models. ImageNet normalization stats still work because the pretrained weights expect inputs with that distribution. Horizontal flipping was part of the original CheXNet training recipe, but it mirrors the heart onto the wrong side of the chest, so treat it as a choice to validate rather than a free win. Avoid vertical flipping entirely: upside-down X-rays are not a realistic augmentation.
Training with BCE Loss
Multi-label classification uses BCELoss (or BCEWithLogitsLoss if you skip the sigmoid in the model). Since we already have sigmoid in our model, we use plain BCELoss.
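The two options compute the same loss; they differ in where the sigmoid lives and how robust the arithmetic is. This sketch checks the equivalence on random logits (a stand-in for model outputs: batch of 4, 14 classes), then feeds in one extreme logit, where BCELoss hits its documented clamp of -100 on the log term while BCEWithLogitsLoss still returns the exact value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 14)                     # raw scores: batch of 4, 14 pathologies
targets = torch.randint(0, 2, (4, 14)).float()  # random multi-hot labels

# Option A (this tutorial): sigmoid inside the model, then BCELoss
loss_a = nn.BCELoss()(torch.sigmoid(logits), targets)
# Option B: raw logits straight into BCEWithLogitsLoss
loss_b = nn.BCEWithLogitsLoss()(logits, targets)
print(loss_a.item(), loss_b.item())  # equal up to floating-point error

# An extreme logit: sigmoid(-200) rounds to exactly 0.0 in float32
extreme = torch.tensor([[-200.0]])
label = torch.tensor([[1.0]])
clamped = nn.BCELoss()(torch.sigmoid(extreme), label)  # 100.0, because log(0) is clamped to -100
exact = nn.BCEWithLogitsLoss()(extreme, label)         # 200.0, the true loss
print(clamped.item(), exact.item())
```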
```python
def train_chexnet(model, train_loader, val_loader, epochs=10, lr=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
    # Cut the learning rate by 10x when validation loss stops improving
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=2
    )

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        avg_train_loss = running_loss / num_batches

        # Validation
        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_batches += 1
        avg_val_loss = val_loss / val_batches

        scheduler.step(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    return model


# Example usage (assuming you have your data ready).
# The __main__ guard matters with num_workers > 0 on platforms that start
# DataLoader workers by re-importing this script (Windows, macOS).
if __name__ == "__main__":
    train_dataset = ChestXrayDataset("train_labels.csv", "images/", train_transforms)
    val_dataset = ChestXrayDataset("val_labels.csv", "images/", val_transforms)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
    trained_model = train_chexnet(model, train_loader, val_loader, epochs=10)
```
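The learning rate of 1e-4 with Adam is a sensible starting point for fine-tuning pretrained models. With patience=2, ReduceLROnPlateau cuts the rate by 10x once validation loss has gone more than 2 consecutive epochs without improving (that is, on the third flat epoch). On ChestX-ray14 (112,000+ images), 10 epochs is often enough for validation loss to level off, but let the validation curve, rather than a fixed epoch count, tell you when training has converged.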
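To see exactly when the drop happens, you can drive the scheduler with a made-up sequence of validation losses. A single dummy parameter stands in for the model, since ReduceLROnPlateau only needs an optimizer whose learning rate it can adjust.

```python
import torch

# One dummy parameter so the optimizer has something to manage
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

# Made-up validation losses: improves on the first two epochs, then stalls
val_losses = [1.00, 0.90, 0.90, 0.90, 0.90]
lrs = []
for val_loss in val_losses:
    optimizer.step()          # no-op here (no gradients); keeps the usual call order
    scheduler.step(val_loss)  # same call as in train_chexnet
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

The rate stays at 1e-4 for the first four calls and becomes 1e-5 on the fifth: the loss improves on calls one and two, stalls on calls three through five, and the cut happens once the stall count exceeds patience=2.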
Generating Grad-CAM Heatmaps
Grad-CAM visualizes which regions of the image drove a particular prediction. For DenseNet-121, you take the feature maps produced by the final batch normalization layer after the last dense block (model.densenet.features[-1], named norm5) and compute the gradients of the target class score with respect to them.
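Written out, the standard Grad-CAM computation looks like this. Here $A^k$ is channel $k$ of the norm5 feature maps (7x7 for a 224x224 input), $y^c$ is the model's output for pathology $c$, and $Z$ is the number of spatial positions (49 here):

$$
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \operatorname{ReLU}\!\left(\sum_k \alpha_k^c A^k\right)
$$

The channel weights $\alpha_k^c$ are the globally averaged gradients, and the ReLU keeps only regions whose features push the score up. The resulting map is min-max normalized and upsampled to the image size for display.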
```python
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms


class GradCAM:
    """Grad-CAM on the output of DenseNet-121's final BatchNorm (norm5).

    Rather than registering module hooks, generate() repeats the steps of
    torchvision's DenseNet.forward (features -> ReLU -> global average pool
    -> classifier) so it can keep a handle on the norm5 output. torchvision
    applies that ReLU in place, which can break a full backward hook on
    norm5, and hooks registered on every call would pile up on the model.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

    def generate(self, input_tensor, target_class):
        """Generate a Grad-CAM heatmap for a specific pathology class."""
        device = next(self.model.parameters()).device
        x = input_tensor.unsqueeze(0).to(device)
        densenet = self.model.densenet

        feature_maps = densenet.features(x)  # norm5 output, (1, 1024, 7, 7) for 224x224
        feature_maps.retain_grad()           # keep .grad on this non-leaf tensor
        pooled = F.adaptive_avg_pool2d(F.relu(feature_maps), (1, 1)).flatten(1)
        output = densenet.classifier(pooled)  # (1, 14) probabilities

        # Backpropagating through the sigmoid rescales every gradient by the
        # same positive factor, which the normalization below essentially
        # cancels (unless the sigmoid is saturated and the gradient underflows)
        self.model.zero_grad()
        output[0, target_class].backward()

        # Global average pool the gradients: one weight per channel
        weights = feature_maps.grad.mean(dim=(2, 3), keepdim=True)  # (1, C, 1, 1)
        cam = (weights * feature_maps.detach()).sum(dim=1, keepdim=True)  # (1, 1, H, W)
        cam = torch.relu(cam).squeeze().cpu().numpy()

        # Normalize to [0, 1]
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam, output[0].detach().cpu().numpy()


def visualize_gradcam(image_path, model, target_class, transform):
    """Show an X-ray with a Grad-CAM overlay for one pathology."""
    original = Image.open(image_path).convert("L")
    input_tensor = transform(original)

    grad_cam = GradCAM(model)
    heatmap, predictions = grad_cam.generate(input_tensor, target_class)

    # The heatmap covers only the region the model saw. With val_transforms
    # (Resize(256) + CenterCrop(224)) that is the center crop, so display it.
    display = transforms.CenterCrop(224)(transforms.Resize(256)(original))

    # Upsample the heatmap to the displayed image size
    heatmap_resized = np.array(
        Image.fromarray((heatmap * 255).astype(np.uint8)).resize(
            display.size, resample=Image.BILINEAR
        )
    ) / 255.0

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(display, cmap="gray")
    axes[0].set_title("Input X-ray (center crop)")
    axes[0].axis("off")

    axes[1].imshow(heatmap_resized, cmap="jet")
    axes[1].set_title(f"Grad-CAM: {PATHOLOGIES[target_class]}")
    axes[1].axis("off")

    axes[2].imshow(display, cmap="gray")
    axes[2].imshow(heatmap_resized, cmap="jet", alpha=0.4)
    axes[2].set_title(f"Overlay (p={predictions[target_class]:.3f})")
    axes[2].axis("off")

    plt.tight_layout()
    plt.savefig("gradcam_output.png", dpi=150, bbox_inches="tight")
    plt.show()


# Example: visualize pneumonia attention (index 6)
# visualize_gradcam("sample_xray.png", trained_model, target_class=6, transform=val_transforms)
```
The heatmap highlights where the model focuses for a given pathology. For pneumonia, you should see activation over lung opacities. For cardiomegaly, expect attention around the heart silhouette. If the model highlights irrelevant areas (like the image borders or text annotations), treat it as a warning sign: the model may be relying on shortcuts such as burned-in markers, which often points to data leakage or poor training.
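If you want to screen many heatmaps rather than eyeball each one, a rough heuristic is to measure how much of the heatmap's mass sits near the image edges. The helper below, border_attention_fraction, is a hypothetical sketch rather than a standard metric: the 10% band width is an arbitrary default, and what fraction counts as suspicious is a judgment call for your own data. It expects a 2-D non-negative map such as the normalized output of GradCAM.generate.

```python
import numpy as np


def border_attention_fraction(cam, border=0.1):
    """Fraction of total heatmap mass lying in an outer band of the image.

    cam:    2-D array of non-negative Grad-CAM values
    border: band width as a fraction of each side (0.1 = outer 10%)
    """
    cam = np.asarray(cam, dtype=np.float64)
    h, w = cam.shape
    bh, bw = max(1, int(h * border)), max(1, int(w * border))
    total = cam.sum()
    if total == 0:
        return 0.0
    inner = cam[bh:h - bh, bw:w - bw].sum()
    return float((total - inner) / total)


# Synthetic examples: attention in the middle of the image vs. stuck on the left edge
centered = np.zeros((224, 224))
centered[80:150, 60:170] = 1.0
edge = np.zeros((224, 224))
edge[:, :15] = 1.0
print(border_attention_fraction(centered))  # 0.0
print(border_attention_fraction(edge))      # 1.0
```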
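Standard accuracy doesn’t work well for multi-label problems like this one: most pathologies are absent from most images, so a model that always predicts “absent” can score high accuracy while detecting nothing. Use per-class AUROC instead, which is the standard metric for CheXNet-style models.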
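To make the accuracy problem concrete, here is a toy case with made-up numbers: one rare pathology with 1 positive among 100 images, and a useless model that gives every image the same low score.

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# Hypothetical rare pathology: 1 positive among 100 images
labels = np.zeros(100, dtype=int)
labels[0] = 1

# A model that outputs the same low probability for every image
preds = np.full(100, 0.01)

acc = accuracy_score(labels, (preds >= 0.5).astype(int))
auc = roc_auc_score(labels, preds)
print(f"Accuracy: {acc:.2f}, AUROC: {auc:.2f}")  # Accuracy: 0.99, AUROC: 0.50
```

Accuracy rewards the model for the 99 easy negatives, while AUROC correctly reports chance-level ranking.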
```python
import numpy as np
import torch
from sklearn.metrics import roc_auc_score


def evaluate_auroc(model, data_loader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            outputs = model(images)
            all_preds.append(outputs.cpu().numpy())
            all_labels.append(labels.numpy())
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    aurocs = {}
    for i, pathology in enumerate(PATHOLOGIES):
        try:
            aurocs[pathology] = roc_auc_score(all_labels[:, i], all_preds[:, i])
        except ValueError:
            aurocs[pathology] = float("nan")  # Only one class present in this split

    mean_auroc = np.nanmean(list(aurocs.values()))
    print(f"\nMean AUROC: {mean_auroc:.4f}\n")
    for pathology, auc in aurocs.items():
        print(f"  {pathology:<22s} AUROC: {auc:.4f}")
    return aurocs


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# evaluate_auroc(trained_model, val_loader, device)
```
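The original CheXNet paper reported per-pathology AUROCs averaging about 0.841 across all 14 pathologies. Training refinements such as cosine learning-rate annealing and better augmentation may push that higher. Mixed precision mainly makes training faster and lighter on GPU memory, which leaves room for larger batches or more experiments, rather than improving AUROC by itself.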
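Cosine annealing is a small change to train_chexnet: create a CosineAnnealingLR instead of ReduceLROnPlateau and call scheduler.step() once per epoch with no argument. This sketch uses T_max equal to the epoch count and an assumed floor of eta_min=1e-6, with a dummy parameter standing in for model.parameters() so you can inspect the schedule on its own. Whether it beats the plateau schedule on your data is something to confirm on validation AUROC.

```python
import torch

epochs, base_lr = 10, 1e-4
param = torch.nn.Parameter(torch.zeros(1))  # stand-in for model.parameters()
optimizer = torch.optim.Adam([param], lr=base_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

lrs = []
for epoch in range(epochs):
    # ... one epoch of training and validation would go here ...
    optimizer.step()   # no-op here (no gradients)
    scheduler.step()   # no validation loss needed, unlike ReduceLROnPlateau
    lrs.append(optimizer.param_groups[0]["lr"])

# The rate decays smoothly along a half cosine and ends at eta_min
print([f"{lr:.2e}" for lr in lrs])
```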
Common Errors and Fixes
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[16, 1, 224, 224] to have 3 channels, but got 1 channels instead
Your images are loading as single-channel grayscale. Add transforms.Grayscale(num_output_channels=3) to your transform pipeline, or convert to RGB when loading: Image.open(path).convert("RGB").
Loss is NaN after a few batches
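First rule out bad data: a corrupt image or a NaN in the label CSV propagates straight into the loss, and a learning rate that is too high can make training diverge. Beyond that, a sigmoid followed by nn.BCELoss() is the less stable combination. PyTorch clamps BCELoss’s log terms at -100, so a saturated sigmoid gives a huge but finite loss rather than log(0), but nn.BCEWithLogitsLoss() is more numerically stable because it fuses the sigmoid and the log into a single log-sum-exp computation. It is also the variant to use under mixed-precision autocast, which raises an error for BCELoss. To switch, remove nn.Sigmoid() from the classifier head, train on the raw logits, and apply torch.sigmoid() only when you need probabilities.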
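All predictions are nearly identical (model not learning)
Check your labels. ChestX-ray14 ships its findings as pipe-separated strings such as "Atelectasis|Effusion" (and "No Finding" for normal images) that need to be split and one-hot encoded into the 14 pathology columns. If a parsing bug leaves the label tensor all zeros, the model has nothing to learn from and simply pushes every output toward 0.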
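A minimal parsing sketch with pandas. It assumes the NIH metadata columns are named "Image Index" and "Finding Labels" (check your copy of the CSV), and it produces the image_path-plus-14-columns layout that ChestXrayDataset expects. "No Finding" rows correctly end up all zeros.

```python
import pandas as pd

PATHOLOGIES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax",
    "Consolidation", "Edema", "Emphysema", "Fibrosis",
    "Pleural_Thickening", "Hernia",
]


def one_hot_findings(meta: pd.DataFrame) -> pd.DataFrame:
    """Turn pipe-separated 'Finding Labels' into one 0/1 column per pathology."""
    out = pd.DataFrame({"image_path": meta["Image Index"]})
    findings = meta["Finding Labels"].str.split("|")  # each row becomes a list of names
    for name in PATHOLOGIES:
        out[name] = findings.apply(lambda labels: int(name in labels))
    return out


# Tiny stand-in for the real metadata file
meta = pd.DataFrame({
    "Image Index": ["a.png", "b.png", "c.png"],
    "Finding Labels": ["Atelectasis|Effusion", "No Finding", "Hernia"],
})
labels_df = one_hot_findings(meta)
print(labels_df[["image_path", "Atelectasis", "Effusion", "Hernia"]])
# labels_df.to_csv("train_labels.csv", index=False)
```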
CUDA out of memory during Grad-CAM
Grad-CAM requires gradients, so torch.no_grad() won’t work here. Reduce batch size to 1 for Grad-CAM inference, or move to CPU. You only need Grad-CAM for visualization, not training.
Poor AUROC on specific pathologies (Hernia, Nodule)
Some pathologies have very few positive examples in ChestX-ray14. Hernia has about 200 positive cases out of 112,000 images. Use weighted BCE loss to handle class imbalance:
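```python
import pandas as pd
import torch
import torch.nn as nn

train_df = pd.read_csv("train_labels.csv")  # the same CSV ChestXrayDataset reads
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Positive weight per class = negatives / positives (clip avoids division by zero)
pos_counts = train_df[PATHOLOGIES].sum()
neg_counts = len(train_df) - pos_counts
pos_weights = torch.tensor((neg_counts / pos_counts.clip(lower=1)).values, dtype=torch.float32)
# pos_weight exists only on BCEWithLogitsLoss: remove nn.Sigmoid() from the model head
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights.to(device))
```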
Images look wrong after transforms
To debug, visualize a batch before training. Undo the normalization and check the images look like actual X-rays:
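```python
import matplotlib.pyplot as plt


def show_batch(images, labels, mean, std):
    """Un-normalize the first image of a batch and display it with its labels."""
    img = images[0].clone()
    for c in range(3):
        img[c] = img[c] * std[c] + mean[c]
    # All three channels are identical copies, so showing one is enough
    plt.imshow(img[0].numpy(), cmap="gray", vmin=0, vmax=1)
    active = [PATHOLOGIES[i] for i, v in enumerate(labels[0]) if v == 1]
    plt.title(", ".join(active) if active else "No Finding")
    plt.show()


# images, labels = next(iter(train_loader))
# show_batch(images, labels, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
```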
Where to Get the Data
The ChestX-ray14 dataset from NIH contains 112,120 frontal-view chest X-rays with 14 disease labels. You can download it from the NIH Clinical Center website. CheXpert from Stanford is another option with 224,316 images and uncertainty labels. Both datasets require agreeing to a data use agreement before download.
For quick experimentation, start with a small subset. The full ChestX-ray14 dataset is about 42 GB of images.
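One way to carve out a subset, assuming your labels are already in the image_path + 14-column format used by ChestXrayDataset. make_subset is a hypothetical helper, and the 5% fraction, seed, and file names are placeholders. ChestX-ray14 often has several images per patient, so a plain row sample like this is fine for smoke tests, but for a train/validation split you would want all of a patient's images on the same side to avoid leakage.

```python
import pandas as pd


def make_subset(df: pd.DataFrame, frac: float = 0.05, seed: int = 42) -> pd.DataFrame:
    """Return a reproducible random fraction of the rows in a labels DataFrame."""
    return df.sample(frac=frac, random_state=seed).reset_index(drop=True)


# Stand-in for a real labels CSV: 1,000 rows in the image_path + pathology format
full = pd.DataFrame({
    "image_path": [f"{i:05d}.png" for i in range(1000)],
    "Hernia": [1 if i % 250 == 0 else 0 for i in range(1000)],
})
small = make_subset(full, frac=0.05)
print(len(small))  # 50

# With real files:
# make_subset(pd.read_csv("train_labels.csv")).to_csv("train_labels_small.csv", index=False)
```

Because the seed is fixed, rerunning the helper returns the same rows, which keeps quick experiments comparable.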