The Quick Version

Ray gives you three libraries that handle the hardest parts of distributed ML: Ray Train for distributed training, Ray Serve for scalable inference, and Ray Data for streaming data pipelines. Instead of wrestling with torch.distributed, NCCL configs, and custom serving infrastructure, you wrap your existing PyTorch code and let Ray handle the coordination.

Install Ray and get a distributed training job running in about 20 lines of Python:

pip install "ray[train]" torch torchvision

import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func():
    import torch
    import ray.train.torch

    model = torch.nn.Linear(10, 1)
    model = ray.train.torch.prepare_model(model)

    for epoch in range(5):
        loss = torch.tensor(1.0)  # your real training step here
        ray.train.report({"loss": loss.item(), "epoch": epoch})

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
print(f"Final loss: {result.metrics['loss']}")

That launches your training function across 4 GPU workers with DistributedDataParallel already configured. No init_process_group, no rank management, no manual device placement.

Ray Train: Distributed Training Without the Boilerplate

The key to Ray Train is two wrapper functions: prepare_model and prepare_data_loader. They handle the DDP wrapping, distributed sampling, and device placement that you’d otherwise configure manually.

Here’s a realistic training script that trains a ResNet-18 from scratch on FashionMNIST across multiple GPUs:

import os
import tempfile
import torch
from filelock import FileLock
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train
import ray.train.torch
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer

def train_func():
    # ResNet-18 from scratch; swap conv1 so it accepts 1-channel (grayscale) images
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # prepare_model wraps in DDP and places on correct device
    model = ray.train.torch.prepare_model(model)

    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    transform = Compose([ToTensor(), Normalize((0.2860,), (0.3203,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    # Workers on the same node share data_dir; the lock stops them from
    # downloading into it at the same time
    with FileLock(os.path.join(tempfile.gettempdir(), "fashion_mnist.lock")):
        train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # prepare_data_loader adds DistributedSampler and handles device placement
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    for epoch in range(10):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as tmpdir:
            torch.save(model.module.state_dict(), os.path.join(tmpdir, "model.pt"))
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
            )

scaling_config = ScalingConfig(num_workers=4, use_gpu=True)
run_config = RunConfig(storage_path="/tmp/ray_results", name="fashion_mnist_run")

trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)
result = trainer.fit()

Batch Size Math

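One thing that trips people up: batch_size in your DataLoader stays per-worker. Ray Train splits the data across workers with a DistributedSampler, so if you set batch_size=128 and run 4 workers, each worker processes 128 samples per step and your effective global batch size is 512. Adjust your learning rate accordingly. The linear scaling rule (multiply the LR by the number of workers) is a common starting point, but it was derived for SGD; with Adam, treat it as a first guess and tune from there.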
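A small sketch of that arithmetic, assuming the script’s lr=0.001 was tuned for a single worker at batch size 128 (an assumption; the script doesn’t say). scaled_learning_rate is a hypothetical helper, not part of Ray.

```python
def scaled_learning_rate(base_lr: float, per_worker_batch: int, num_workers: int,
                         reference_batch: int) -> tuple[int, float]:
    """Return (global batch size, learning rate scaled by the linear rule).

    Assumes base_lr was tuned for a global batch of reference_batch samples.
    """
    global_batch = per_worker_batch * num_workers
    # Linear scaling rule: scale the LR by the same factor as the global batch
    lr = base_lr * global_batch / reference_batch
    return global_batch, lr


# The script above: lr=0.001 (assumed tuned at batch 128), now 4 workers x 128 samples
global_batch, lr = scaled_learning_rate(0.001, per_worker_batch=128, num_workers=4,
                                        reference_batch=128)
print(global_batch, lr)  # 512 0.004
```

Inside train_func you could read the worker count from ray.train.get_context().get_world_size() instead of hard-coding 4, so the LR follows whatever ScalingConfig you launch with.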

Multi-Node Storage

On a single machine, local paths work fine for checkpoints. The moment you go multi-node, you need shared storage. Pass a cloud path:

run_config = RunConfig(storage_path="s3://my-bucket/ray-results", name="training_run")

Without this, you’ll get checkpoint errors on worker nodes that can’t access the head node’s filesystem.

Ray Serve: Production Inference at Scale

Ray Serve handles the serving side. For LLM inference specifically, Ray Serve LLM wraps vLLM and gives you an OpenAI-compatible API with tensor parallelism, autoscaling, and request batching out of the box.

from ray.serve.llm import LLMConfig, build_openai_app

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="my-llama",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    accelerator_type="A100",
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=4,
        )
    ),
    engine_kwargs=dict(
        max_model_len=8192,
        tensor_parallel_size=2,
    ),
)

app = build_openai_app({"llm_configs": [llm_config]})

Save the code above as serve_config.py, deploy it with serve run serve_config:app, and hit it with any OpenAI-compatible client:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
response = client.chat.completions.create(
    model="my-llama",
    messages=[{"role": "user", "content": "Explain gradient checkpointing in 2 sentences."}],
)
print(response.choices[0].message.content)

For larger models that don’t fit on a single node, combine tensor and pipeline parallelism. In engine_kwargs, set tensor_parallel_size to the number of GPUs per node and pipeline_parallel_size to the number of nodes. A 70B model across 2 nodes with 4 GPUs each would use tensor_parallel_size=4, pipeline_parallel_size=2.
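The sizing heuristic above, plus a rough weight-memory check, can be sketched in plain Python. Both helpers are hypothetical; they assume fp16/bf16 weights (2 bytes per parameter) and 1 GB = 1e9 bytes, and they ignore the KV cache and activations, which need headroom on top. The same function also shows that the Serve config earlier (tensor_parallel_size=2, max_replicas=4) can grow to 8 A100s at full scale.

```python
def parallel_layout(num_nodes: int, gpus_per_node: int, max_replicas: int = 1) -> dict:
    """Heuristic from the text: tensor parallelism inside a node,
    pipeline parallelism across nodes."""
    tp = gpus_per_node   # TP splits every layer, so keep it on fast intra-node links
    pp = num_nodes       # PP splits the layer stack into stages, one stage per node
    return {
        "tensor_parallel_size": tp,
        "pipeline_parallel_size": pp,
        "gpus_per_replica": tp * pp,
        "gpus_at_max_replicas": tp * pp * max_replicas,
    }


def weight_gb_per_gpu(params_billion: float, bytes_per_param: int, num_gpus: int) -> float:
    """Weights only; billions of params x bytes per param = GB (1 GB = 1e9 bytes)."""
    return params_billion * bytes_per_param / num_gpus


layout = parallel_layout(num_nodes=2, gpus_per_node=4)
print(layout["tensor_parallel_size"], layout["pipeline_parallel_size"])  # 4 2
print(weight_gb_per_gpu(70, 2, layout["gpus_per_replica"]))              # 17.5
```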

Ray Data: Streaming Preprocessing

Ray Data connects your raw data to your training loop with streaming execution. Instead of loading everything into memory first, it processes batches on-the-fly and keeps your GPUs fed.

import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ds = ray.data.read_parquet("s3://my-bucket/training-data/")

ds = ds.map(lambda row: {"text": row["text"].lower(), "label": row["label"]})
ds = ds.filter(lambda row: len(row["text"]) > 10)

# Stream directly to training workers (train_func, shown next, must be defined first)
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    datasets={"train": ds},
)
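The lambdas above run once per row. Ray Data also accepts a function over whole batches through ds.map_batches, which is usually the faster route for simple transforms. Below is a sketch of the same lower-case-and-filter step written that way, assuming batches arrive as a dict of NumPy arrays (what batch_format="numpy" requests); clean_batch is a hypothetical name, and it can be tested without Ray at all.

```python
import numpy as np


def clean_batch(batch: dict) -> dict:
    """Lower-case text and drop rows of 10 characters or fewer, one batch at a time.

    Expects a dict of NumPy arrays with "text" and "label" columns.
    """
    texts = np.array([str(t).lower() for t in batch["text"]], dtype=object)
    keep = np.array([len(t) > 10 for t in texts], dtype=bool)
    # A batch function may return fewer rows than it received
    return {"text": texts[keep], "label": np.asarray(batch["label"])[keep]}


sample = {"text": np.array(["Hello World Again", "short"], dtype=object),
          "label": np.array([1, 0])}
print(clean_batch(sample))
```

With Ray, the usage would be ds = ray.data.read_parquet("s3://my-bucket/training-data/").map_batches(clean_batch, batch_format="numpy"), replacing both the map and the filter calls.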

Inside your training function, access the data through Ray’s iterator:

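import torch
import ray.train

def train_func():
    train_ds = ray.train.get_dataset_shard("train")
    for epoch in range(10):
        # iter_torch_batches() needs every column to be numeric; raw "text" strings
        # can't become tensors, so read NumPy batches and convert the label column.
        for batch in train_ds.iter_batches(batch_size=64):
            texts = list(batch["text"])               # strings: tokenize before the model
            labels = torch.as_tensor(batch["label"])  # numeric column -> tensor
            # forward pass, loss, backward, optimizer step go here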

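This keeps CPU preprocessing and GPU training running in parallel. Ray applies backpressure automatically: it buffers batches ahead of the trainer so short preprocessing stalls don’t leave the GPUs idle, and it throttles upstream stages when output piles up so fast preprocessing doesn’t blow up memory. Buffering only hides transient slowness, though. If preprocessing is consistently slower than training, the GPUs will still wait, and the fix is more CPU capacity for the preprocessing stage.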
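The backpressure idea itself fits in a few lines of standard-library Python. This is an analogy, not Ray’s implementation: Ray Data budgets resources across operators, while this sketch uses a single bounded queue. A producer thread stands in for CPU preprocessing, the consumer loop stands in for training, and put() blocking on a full queue is what caps memory. The names stream and preprocess are hypothetical.

```python
import queue
import threading


def stream(num_batches: int, buffer_size: int) -> tuple[list[int], int]:
    """Run a producer/consumer pipeline; return consumed items and peak buffer depth."""
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the stream

    def preprocess() -> None:
        for i in range(num_batches):
            buf.put(i * i)  # blocks while the buffer is full: this is the backpressure
        buf.put(done)

    producer = threading.Thread(target=preprocess)
    producer.start()

    consumed, peak = [], 0
    while True:
        peak = max(peak, buf.qsize())  # batches buffered ahead of the consumer
        item = buf.get()               # blocks while the buffer is empty (consumer waits)
        if item is done:
            break
        consumed.append(item)          # stand-in for one training step
    producer.join()
    return consumed, peak


items, peak = stream(num_batches=20, buffer_size=4)
print(len(items), peak <= 4)  # 20 True
```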

Common Errors and Fixes

“No available node types can fulfill resource request”

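You asked for more GPUs than your cluster has. Run ray status and compare the GPU total against your ScalingConfig: with use_gpu=True, each worker requests one GPU by default. If you’re on a single machine with 2 GPUs, num_workers=4 with use_gpu=True will hang waiting for GPUs that never appear.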
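A quick pre-flight check can catch this before calling fit(). The sketch below is a hypothetical helper; it expects a dict shaped like the one ray.cluster_resources() returns (for example {"CPU": 16.0, "GPU": 2.0}) and checks GPUs only. Passing the check doesn’t guarantee the workers can be placed, since other jobs may hold some of those GPUs, but failing it means the trainer can never start.

```python
def gpu_request_fits(num_workers: int, gpus_per_worker: float,
                     cluster_resources: dict) -> bool:
    """True if the cluster has at least as many GPUs in total as the trainer requests."""
    needed = num_workers * gpus_per_worker
    return cluster_resources.get("GPU", 0.0) >= needed


# With Ray running: gpu_request_fits(4, 1, ray.cluster_resources())
single_box = {"CPU": 16.0, "GPU": 2.0}
print(gpu_request_fits(4, 1, single_box))  # False
print(gpu_request_fits(2, 1, single_box))  # True
```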

Environment variables not reaching workers

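Environment variables set in the driver’s shell aren’t reliably visible to workers, and on a cluster that was started separately they never reach them. Pass them explicitly through the runtime environment: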

# Ray sets CUDA_VISIBLE_DEVICES for GPU workers itself, so leave it out of env_vars
ray.init(runtime_env={"env_vars": {"HF_TOKEN": "hf_xxx"}})

Checkpointing fails on multi-node clusters

If you see FileNotFoundError during checkpoint saving, you’re using a local path on a multi-node setup. Switch to S3, GCS, or an NFS mount in your RunConfig.

OOM on workers

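Under DDP, each worker holds a full copy of the model. A 7B model needs ~14GB just for its fp16 weights, and full training with Adam adds gradients and optimizer state on top, so it won’t fit on a 16GB GPU even before activations are counted. Changing num_workers doesn’t help, because every worker still gets its own full replica. Use gradient checkpointing to trim activation memory when a model almost fits; when the training state itself is too large, shard it across workers with FSDP (ray.train.torch.prepare_model accepts parallel_strategy="fsdp") and make sure the combined GPU memory can hold it.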
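The numbers above come from simple per-parameter accounting. The sketch assumes the usual back-of-envelope figures for mixed-precision Adam (fp16 weights and gradients, fp32 master weights, two fp32 Adam moments, 16 bytes per parameter), ignores activations, and treats sharding as an ideal even split; the helper names are hypothetical. Even fully sharded across four 16GB GPUs, a 7B model’s training state comes to about 28GB per GPU.

```python
# Bytes per parameter, using common back-of-envelope accounting
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    # fp16 weights (2) + fp16 grads (2) + fp32 master weights (4) + Adam m and v (4 + 4)
    "mixed_precision_adam": 16,
}


def per_gpu_gb(params_billion: float, mode: str) -> float:
    """Per-GPU memory under DDP (full replica per worker), excluding activations.
    Billions of params x bytes per param = GB, with 1 GB = 1e9 bytes."""
    return params_billion * BYTES_PER_PARAM[mode]


def sharded_per_gpu_gb(params_billion: float, mode: str, num_workers: int) -> float:
    """Idealized full sharding (FSDP-style): the state is split evenly across workers."""
    return per_gpu_gb(params_billion, mode) / num_workers


print(per_gpu_gb(7, "fp16_weights"))                    # 14
print(per_gpu_gb(7, "mixed_precision_adam"))            # 112
print(sharded_per_gpu_gb(7, "mixed_precision_adam", 4))  # 28.0
```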

Placement group conflicts

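If tasks inside a Ray Train worker launch nested tasks, you may see scheduling hangs as those tasks wait for resources inside the parent placement group. Pass scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=None) to the nested tasks (import it from ray.util.scheduling_strategies) to schedule them outside the parent placement group.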

When to Use Ray vs. Raw PyTorch

Use raw torch.distributed when you have a single multi-GPU machine, a simple training loop, and no serving requirements. It’s less abstraction and fewer dependencies.

Use Ray when you need any of these: multi-node training, fault tolerance (Ray restarts failed workers), a unified pipeline from data loading through training to serving, or autoscaling inference deployments. Ray’s overhead on a single machine is minimal, but the real payoff comes when you’re coordinating across nodes or when you want training and serving to share the same cluster.