Training on your full dataset is rarely the right move. When you have hundreds of millions of rows, training on everything blows up memory and wastes compute on easy examples, and naive random subsampling does nothing about class imbalance: rare classes stay rare, or vanish from small samples entirely. A proper sampling pipeline picks the right data at the right time. Here’s how to build one.

Stratified Sampling with scikit-learn

Stratified sampling keeps your class distribution intact. If your dataset is 70% class A and 30% class B, your sample will match that ratio. This matters a lot when you’re downsampling for faster iteration or creating validation splits.

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Create a synthetic dataset with imbalanced classes
np.random.seed(42)
n_samples = 100_000
X = np.random.randn(n_samples, 10)
y = np.random.choice([0, 1, 2], size=n_samples, p=[0.7, 0.2, 0.1])

print(f"Full dataset class distribution: {np.bincount(y) / len(y)}")

# Stratified sample: take 10% while preserving class ratios
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
for _, sample_idx in splitter.split(X, y):
    X_sample = X[sample_idx]
    y_sample = y[sample_idx]

print(f"Sample class distribution:       {np.bincount(y_sample) / len(y_sample)}")
print(f"Sample size: {len(X_sample)}")
# Both distributions will be nearly identical, close to [0.7, 0.2, 0.1]

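Why StratifiedShuffleSplit instead of train_test_split? For a single split, train_test_split(X, y, test_size=0.1, stratify=y) preserves class ratios just as well. StratifiedShuffleSplit is a splitter object: set n_splits and its split() method yields that many independent stratified splits, which is useful for repeated random subsampling or Monte Carlo cross-validation. It draws without replacement, so it is not a bootstrap.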
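
A minimal sketch comparing the two on synthetic labels (the sizes and seeds are arbitrary): one stratified 10% subsample via train_test_split, then five independent ones from a single StratifiedShuffleSplit.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

rng = np.random.default_rng(0)
X = rng.standard_normal((10_000, 5))
y = rng.choice([0, 1, 2], size=10_000, p=[0.7, 0.2, 0.1])
full_ratio = np.bincount(y) / len(y)

# One stratified 10% subsample: returns (X_rest, X_small, y_rest, y_small)
_, X_small, _, y_small = train_test_split(
    X, y, test_size=0.1, stratify=y, random_state=0
)

# Five independent stratified 10% subsamples from one splitter
splitter = StratifiedShuffleSplit(n_splits=5, test_size=0.1, random_state=0)
subsamples = [idx for _, idx in splitter.split(X, y)]

for i, idx in enumerate(subsamples):
    ratio = np.bincount(y[idx], minlength=3) / len(idx)
    print(f"split {i}: {len(idx)} rows, class ratios {np.round(ratio, 3)}")
```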

Weighted Sampling with PyTorch DataLoader

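When classes are severely imbalanced (think 99% negative, 1% positive), stratified sampling alone won’t cut it, because it faithfully preserves that 99:1 ratio. You want to oversample the minority class during training. PyTorch’s WeightedRandomSampler handles this at the DataLoader level by drawing indices with probability proportional to a per-sample weight.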

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import numpy as np

# Simulate an imbalanced dataset: 95% class 0, 5% class 1
np.random.seed(42)
n_samples = 50_000
features = np.random.randn(n_samples, 8).astype(np.float32)
labels = np.random.choice([0, 1], size=n_samples, p=[0.95, 0.05])

# Compute per-sample weights (inverse class frequency)
class_counts = np.bincount(labels)
class_weights = 1.0 / class_counts
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(
    weights=torch.from_numpy(sample_weights),
    num_samples=len(labels),  # sample this many per epoch
    replacement=True,         # must be True for oversampling
)

dataset = TensorDataset(
    torch.from_numpy(features),
    torch.from_numpy(labels),
)
loader = DataLoader(dataset, batch_size=256, sampler=sampler)

# Verify: each batch should be roughly 50/50
batch_features, batch_labels = next(iter(loader))
print(f"Batch class distribution: {torch.bincount(batch_labels.int())}")
# Expect approximately [128, 128] instead of [243, 13]

Setting replacement=True is critical. Without it, the sampler can’t draw any example more than once, so it can’t oversample the minority class. The trade-off is repetition: each epoch, the same few minority examples are drawn many times while a large share of majority examples are never drawn at all, which raises the risk of overfitting to the minority class. Batches are also balanced only in expectation, so individual batches will drift from 50/50.
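
To see the repetition trade-off concretely, you can draw one epoch of indices straight from the sampler and count how often each example appears. This sketch reuses the 95/5 setup from above; the explicit torch.Generator is only there to make the run reproducible.

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

np.random.seed(42)
labels = np.random.choice([0, 1], size=50_000, p=[0.95, 0.05])
sample_weights = (1.0 / np.bincount(labels))[labels]  # inverse class frequency

sampler = WeightedRandomSampler(
    weights=torch.from_numpy(sample_weights),
    num_samples=len(labels),
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)
drawn = np.array(list(sampler))  # one epoch of dataset indices

draw_counts = np.bincount(drawn, minlength=len(labels))
minority = labels == 1
print(f"minority share of draws:     {minority[drawn].mean():.3f}")
print(f"mean draws per minority row: {draw_counts[minority].mean():.1f}")
print(f"majority rows never drawn:   {(draw_counts[~minority] == 0).mean():.1%}")
```

You should see roughly half the draws come from the minority class, each minority row drawn around 10 times, and close to 60% of majority rows skipped for that epoch.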

Reservoir Sampling for Streaming Data

When data arrives as a stream and you don’t know the total size upfront, reservoir sampling gives you a uniform random sample in a single pass with O(k) memory, where k is your desired sample size.

import random
import numpy as np

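def reservoir_sample(stream, k, seed=42):
    """Sample k items uniformly from an iterable of unknown length (Algorithm R)."""
    rng = random.Random(seed)  # local RNG: leaves the global random state untouched
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)  # fill the reservoir with the first k items
        else:
            j = rng.randint(0, i)  # inclusive on both ends: i + 1 choices
            if j < k:
                reservoir[j] = item  # item i enters with probability k / (i + 1)
    return reservoir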

# Simulate a data stream (e.g., reading from a log file line by line)
def data_stream(n=1_000_000):
    rng = np.random.default_rng(0)
    for i in range(n):
        yield {"id": i, "value": rng.normal(), "label": int(rng.random() > 0.8)}

# Sample 10,000 records from the million-record stream
sample = reservoir_sample(data_stream(), k=10_000)
print(f"Sampled {len(sample)} records from stream")
print(f"First record: {sample[0]}")
print(f"Label distribution: {sum(r['label'] for r in sample) / len(sample):.3f}")
# Should be close to 0.2, matching the stream's underlying distribution

This is the algorithm to reach for when you’re processing log files, Kafka topics, or any source you can’t seek back through. After n items have streamed past, each one is in the reservoir with the same probability, k/n.
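
Why k/n: take an item at 0-based position i ≥ k. It enters the reservoir with probability k/(i+1). Each later item at position m overwrites one specific slot with probability 1/(m+1), so our item survives that step with probability m/(m+1). The product telescopes:

```latex
P(\text{item } i \text{ in final sample})
  = \frac{k}{i+1} \prod_{m=i+1}^{n-1} \frac{m}{m+1}
  = \frac{k}{i+1} \cdot \frac{i+1}{n}
  = \frac{k}{n}
```

For the first k items (i < k), which enter with probability 1, the survival product runs from m = k and also comes out to k/n.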

Curriculum Sampling: Easy First, Hard Later

Curriculum learning trains on easy examples first, then gradually introduces harder ones. On many tasks this speeds up convergence and can improve final accuracy, though the gains depend on the task and on how you measure difficulty. You need a difficulty score for each sample: loss from a previous run, model confidence, or a hand-crafted metric.

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

np.random.seed(42)
n_samples = 100_000
features = np.random.randn(n_samples, 16).astype(np.float32)
labels = np.random.randint(0, 5, size=n_samples).astype(np.int64)

# Difficulty scores: in practice, e.g. per-sample cross-entropy loss from a previous run.
# Random placeholders here. Lower score = easier example.
difficulty_scores = np.abs(np.random.randn(n_samples)) + 0.1

def curriculum_sampler(difficulty_scores, epoch, total_epochs, min_fraction=0.3):
    """Return indices of the easiest examples, growing the pool each epoch."""
    sorted_indices = np.argsort(difficulty_scores)  # easiest first
    # Linear schedule: start with min_fraction of data, end with 100%
    progress = min(epoch / max(total_epochs - 1, 1), 1.0)
    fraction = min_fraction + (1.0 - min_fraction) * progress
    cutoff = int(len(sorted_indices) * fraction)
    return sorted_indices[:cutoff]

dataset = TensorDataset(
    torch.from_numpy(features),
    torch.from_numpy(labels),
)

total_epochs = 10
for epoch in range(total_epochs):
    indices = curriculum_sampler(difficulty_scores, epoch, total_epochs, min_fraction=0.3)
    # SubsetRandomSampler reshuffles the selected pool on every pass
    sampler = SubsetRandomSampler(indices.tolist())
    loader = DataLoader(dataset, batch_size=512, sampler=sampler)
    # for batch_features, batch_labels in loader: ...  (your training step)
    if epoch % 3 == 0:
        print(f"Epoch {epoch}: training on {len(indices):,} samples "
              f"({len(indices)/n_samples:.0%} of dataset)")

# Epoch 0: training on 30,000 samples (30% of dataset)
# Epoch 3: training on 53,333 samples (53% of dataset)
# Epoch 6: training on 76,666 samples (77% of dataset)
# Epoch 9: training on 100,000 samples (100% of dataset)

The difficulty scores typically come from a warm-up run: train briefly on a random subset, then score every example in the full dataset by its per-sample loss, and use those losses as difficulty scores for the curriculum. You can also recompute scores every few epochs so the curriculum adapts as the model improves.
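
A minimal sketch of that warm-up, assuming a classification model trained with cross-entropy. nn.Linear stands in for your real model, and the 20% warm-up subset, learning rate, and batch sizes are arbitrary choices. The details that matter are reduction="none", which keeps one loss value per example, and shuffle=False in the scoring pass, so scores line up with dataset indices.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)
n_samples, n_features, n_classes = 10_000, 16, 5
features = torch.randn(n_samples, n_features)
labels = torch.randint(0, n_classes, (n_samples,))
dataset = TensorDataset(features, labels)

model = nn.Linear(n_features, n_classes)  # stand-in for your real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# 1) Warm-up: one epoch on a random 20% subset
warmup_idx = torch.randperm(n_samples)[: n_samples // 5].tolist()
model.train()
for xb, yb in DataLoader(Subset(dataset, warmup_idx), batch_size=256, shuffle=True):
    optimizer.zero_grad()
    loss_fn(model(xb), yb).backward()
    optimizer.step()

# 2) Score every example by its own loss, in dataset order
per_sample_loss = nn.CrossEntropyLoss(reduction="none")
model.eval()
with torch.no_grad():
    difficulty_scores = torch.cat([
        per_sample_loss(model(xb), yb)
        for xb, yb in DataLoader(dataset, batch_size=1024, shuffle=False)
    ]).numpy()

print(difficulty_scores.shape)  # (10000,): one score per example, lower = easier
```

The resulting array can be passed to curriculum_sampler in place of the random placeholder scores.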

Fast Sampling on Large Parquet Files with Polars

When your dataset lives in parquet files on disk, you don’t want to load everything into memory just to sample. Polars reads parquet lazily and pushes filters down to the file level.

import polars as pl
import numpy as np

# Create a sample parquet file to work with
np.random.seed(42)
n_rows = 2_000_000
df = pl.DataFrame({
    "feature_0": np.random.randn(n_rows),
    "feature_1": np.random.randn(n_rows),
    "feature_2": np.random.randn(n_rows),
    "label": np.random.choice(["cat", "dog", "bird"], size=n_rows, p=[0.6, 0.3, 0.1]),
    "difficulty": np.random.exponential(1.0, size=n_rows),
})
df.write_parquet("/tmp/training_data.parquet")

# Now sample from it efficiently
lazy_df = pl.scan_parquet("/tmp/training_data.parquet")

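# Stratified sample: 5,000 rows per class
# The difficulty filter runs lazily (and can be pushed into the parquet reader);
# map_groups needs an eager DataFrame, so collect the filtered rows first.
filtered = lazy_df.filter(pl.col("difficulty") < 3.0).collect()

stratified = (
    filtered
    .group_by("label")
    .map_groups(lambda group: group.sample(n=min(5_000, len(group)), seed=42))
)

print(f"Stratified sample shape: {stratified.shape}")
print(stratified.group_by("label").len().sort("label"))
# Each class gets 5,000 rows regardless of the original 60/30/10 split
# (min() caps the draw for any class with fewer than 5,000 rows)

# Weighted (Bernoulli) sampling: keep each row with a class-dependent probability.
# Rare classes are kept at a higher rate and common classes are downsampled;
# no row is duplicated, so this rebalances without true oversampling.
weights = {"cat": 0.1, "dog": 0.3, "bird": 0.9}
all_data = lazy_df.collect()  # loads the whole file: fine at 2M rows, not at 200M
all_data = all_data.with_columns(
    pl.col("label").replace_strict(weights, default=0.5).alias("sample_weight")
)
# Keep rows where a uniform random draw is below the row's sample weight
np.random.seed(42)
rand_vals = np.random.rand(len(all_data))
weighted = (
    all_data
    .with_columns(pl.Series("rand", rand_vals))
    .filter(pl.col("rand") < pl.col("sample_weight"))
    .drop("rand")
)
print(f"Weighted sample size: {len(weighted):,}")
print(weighted.group_by("label").len().sort("label"))
# Expected counts: about 120,000 cat, 180,000 dog, 180,000 bird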

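Polars shines here because scan_parquet doesn’t read the file up front; it builds a lazy query. When you collect(), the query optimizer pushes column selections and row filters into the parquet reader, so unused columns are never read and row groups whose statistics rule out the filter can be skipped. Both examples above still collect most of the file before sampling, which is fine at 2M rows but not at hundreds of millions; at that scale, keep the sampling step itself in the lazy query. For multi-file datasets, pass a glob pattern such as scan_parquet("data/*.parquet"), and Polars scans the matching files as a single table.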

Common Errors and Fixes

ValueError: The least populated class in y has only 1 member when using StratifiedShuffleSplit. Stratification needs at least two members in every class, and at least one class falls short. Fix: drop (or merge) classes below a minimum count before stratifying.

import numpy as np

# Drop classes with fewer than 10 samples before the stratified split
# (X and y as in the StratifiedShuffleSplit example above)
classes, counts = np.unique(y, return_counts=True)
valid_mask = np.isin(y, classes[counts >= 10])
X_filtered = X[valid_mask]
y_filtered = y[valid_mask]

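WeightedRandomSampler produces identical batches every epoch. This almost always means the random number generator is being reset each epoch. Without a generator argument, the sampler draws from PyTorch’s global RNG, which advances from one epoch to the next, so calling torch.manual_seed() inside your training loop (or re-creating a seeded torch.Generator every epoch) replays the same draws. Seed once, before the loop, not inside.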
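
A small sketch of the failure and the fix, using a toy four-element weight vector (the weights and the 20-draw epoch length are arbitrary). Re-seeding the global RNG at the top of every epoch replays the same indices; seeding once lets the RNG advance between epochs.

```python
import torch
from torch.utils.data import WeightedRandomSampler

weights = torch.tensor([0.1, 0.1, 0.1, 0.7])
sampler = WeightedRandomSampler(weights, num_samples=20, replacement=True)

# Wrong: seeding inside the loop resets the global RNG every epoch
wrong = []
for epoch in range(3):
    torch.manual_seed(0)
    wrong.append(list(sampler))  # same 20 indices each time

# Right: seed once, then let the global RNG advance across epochs
torch.manual_seed(0)
right = [list(sampler) for epoch in range(3)]

print(wrong[0] == wrong[1] == wrong[2])  # True
print(right[0] == right[1] == right[2])  # almost certainly False
```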

Polars sample fails when n is larger than the group. In a per-group stratified sample, some groups may have fewer rows than the n you request, and DataFrame.sample(n=...) raises an error unless with_replacement=True. Cap n per group with min(n, len(group)) as in the example above, sample with sample(fraction=...) instead of sample(n=...), or drop small groups before grouping with .filter(pl.len().over("label") >= n).

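Reservoir sampling gives biased results. Check that the replacement index j is drawn uniformly from 0 to i inclusive, where i is the 0-based position of the current item, giving i + 1 choices. Off-by-one errors here break the uniform guarantee. Python’s random.randint(0, i), used in the implementation above, is inclusive on both ends; random.randrange(0, i), np.random.randint(0, i), and numpy’s Generator.integers(0, i) all exclude the upper bound, so with those you need i + 1.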
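
To see what the off-by-one does, compare the correct draw with a buggy randint(0, i - 1) version on a 20-item stream with k = 5. Working through the same survival product as the correct case, the buggy version keeps each of the first k items with probability (k - 1)/(n - 1), about 21% here, and every later item with probability k/(n - 1), about 26%, instead of a flat k/n = 25%. The buggy flag is a hypothetical parameter added only for this comparison.

```python
import random
from collections import Counter

def reservoir_sample(stream, k, rng, buggy=False):
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)
        else:
            # Correct: i + 1 choices in [0, i]. Buggy: only i choices in [0, i - 1].
            j = rng.randint(0, i - 1) if buggy else rng.randint(0, i)
            if j < k:
                reservoir[j] = item
    return reservoir

def inclusion_rates(n, k, trials, buggy):
    """Fraction of trials in which each stream item ends up in the reservoir."""
    rng = random.Random(0)
    counts = Counter()
    for _ in range(trials):
        counts.update(reservoir_sample(range(n), k, rng, buggy=buggy))
    return [counts[i] / trials for i in range(n)]

correct = inclusion_rates(n=20, k=5, trials=40_000, buggy=False)
buggy = inclusion_rates(n=20, k=5, trials=40_000, buggy=True)
print(f"correct: first item {correct[0]:.3f}, last item {correct[-1]:.3f}")
print(f"buggy:   first item {buggy[0]:.3f}, last item {buggy[-1]:.3f}")
```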

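Out-of-memory on large parquet files. If collect() blows up, do as much as possible on the lazy frame first: select only the columns you need and apply row filters, so Polars can push both into the parquet reader. .head(n) and .limit(n) cap the row count, but they return the first n rows, not a random sample. LazyFrame has no .sample() method, so rather than collecting the full table just to call DataFrame.sample(), express the sample as a lazy filter (for example, a shuffled per-group row position compared against n) and collect only the result.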