The Quick Version
Ray gives you three libraries that handle the hardest parts of distributed ML: Ray Train for distributed training, Ray Serve for scalable inference, and Ray Data for streaming data pipelines. Instead of wrestling with torch.distributed, NCCL configs, and custom serving infrastructure, you wrap your existing PyTorch code and let Ray handle the coordination.
Install Ray and get a distributed training job running in under 20 lines of config:
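A minimal launch sketch, assuming `ray[train]` and torch are installed; the training function body is a placeholder you fill in with your own loop:

```python
# Minimal Ray Train launch sketch.
# Assumes: pip install "ray[train]" torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func():
    # Your existing per-worker PyTorch training loop goes here.
    pass

trainer = TorchTrainer(
    train_func,
    # 4 workers, one GPU each; Ray wires up DistributedDataParallel for you.
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
```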
That launches your training function across 4 GPU workers with DistributedDataParallel already configured. No init_process_group, no rank management, no manual device placement.
Ray Train: Distributed Training Without the Boilerplate
The key to Ray Train is two wrapper functions: prepare_model and prepare_data_loader. They handle the DDP wrapping, distributed sampling, and device placement that you’d otherwise configure manually.
Here’s a realistic training script that fine-tunes a ResNet on FashionMNIST across multiple GPUs:
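A sketch of what such a script can look like; the model choice (resnet18), epoch count, and hyperparameters are illustrative, and it assumes torchvision is installed:

```python
# Sketch: fine-tune a ResNet on FashionMNIST with Ray Train.
# Assumes: pip install "ray[train]" torch torchvision
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

import ray.train
import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func():
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),  # ResNet expects 3 channels
        transforms.ToTensor(),
    ])
    dataset = datasets.FashionMNIST(
        "/tmp/data", train=True, download=True, transform=transform
    )
    loader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Swaps in a DistributedSampler and moves batches to this worker's device.
    loader = ray.train.torch.prepare_data_loader(loader)

    model = models.resnet18(weights="IMAGENET1K_V1")
    model.fc = nn.Linear(model.fc.in_features, 10)  # 10 FashionMNIST classes
    # Wraps the model in DistributedDataParallel on this worker's GPU.
    model = ray.train.torch.prepare_model(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(2):
        for images, labels in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
        ray.train.report({"epoch": epoch, "loss": loss.item()})

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
trainer.fit()
```

Note there is no `init_process_group`, no rank bookkeeping, and no `.to(device)` calls: the two `prepare_*` wrappers cover all of it.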
Batch Size Math
One thing that trips people up: Ray Train splits your data across workers using a DistributedSampler. If you set batch_size=128 and run 4 workers, each worker processes 128 samples per step, making your effective global batch size 512. Adjust your learning rate accordingly (linear scaling rule: multiply LR by the number of workers).
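The arithmetic, spelled out with the numbers from the example above (the base learning rate is an illustrative value):

```python
# Effective global batch size and linearly scaled learning rate.
per_worker_batch = 128
num_workers = 4
base_lr = 0.01  # LR you tuned for a single worker (illustrative)

global_batch = per_worker_batch * num_workers  # 512
scaled_lr = base_lr * num_workers              # 0.04

print(global_batch, scaled_lr)
```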
Multi-Node Storage
On a single machine, local paths work fine for checkpoints. The moment you go multi-node, you need shared storage. Pass a cloud path:
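For example (the bucket name is a placeholder; an NFS mount path works the same way):

```python
from ray.train import RunConfig

# Checkpoints and results go to shared storage every node can reach.
# "my-bucket" is a placeholder; e.g. "/mnt/shared/experiments" for NFS.
run_config = RunConfig(storage_path="s3://my-bucket/experiments")
```

Pass this as `run_config=run_config` when constructing your `TorchTrainer`.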
Without this, you’ll get checkpoint errors on worker nodes that can’t access the head node’s filesystem.
Ray Serve: Production Inference at Scale
Ray Serve handles the serving side. For LLM inference specifically, Ray Serve LLM wraps vLLM and gives you an OpenAI-compatible API with tensor parallelism, autoscaling, and request batching out of the box.
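A sketch of such a deployment using the `ray.serve.llm` API; the model id, model source, and autoscaling bounds are illustrative, and it assumes a recent Ray with vLLM installed:

```python
# serve_config.py -- sketch of an OpenAI-compatible LLM deployment.
# Assumes: pip install "ray[serve,llm]" vllm
from ray.serve.llm import LLMConfig, build_openai_app

llm_config = LLMConfig(
    model_loading_config={
        "model_id": "qwen-0.5b",                       # name clients request
        "model_source": "Qwen/Qwen2.5-0.5B-Instruct",  # HF model (illustrative)
    },
    deployment_config={
        "autoscaling_config": {"min_replicas": 1, "max_replicas": 2},
    },
    engine_kwargs={"tensor_parallel_size": 1},  # GPUs per replica
)

app = build_openai_app({"llm_configs": [llm_config]})
```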
Deploy it with serve run serve_config:app and hit it with any OpenAI-compatible client:
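For example, with the official `openai` Python package (the model name must match the `model_id` in your serve config):

```python
from openai import OpenAI

# Ray Serve LLM ignores the API key by default; any string works.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="qwen-0.5b",  # must match the model_id you configured
    messages=[{"role": "user", "content": "Summarize Ray Serve in one sentence."}],
)
print(response.choices[0].message.content)
```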
For larger models that don’t fit on a single node, combine tensor and pipeline parallelism. Set tensor_parallel_size to the number of GPUs per node and pipeline_parallel_size to the number of nodes. A 70B model across 2 nodes with 4 GPUs each would use tensor_parallel_size=4, pipeline_parallel_size=2.
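In `engine_kwargs` terms, that layout would look like this (illustrative settings passed through to vLLM):

```python
# 70B model on 2 nodes x 4 GPUs: shard each layer across GPUs within a node (TP),
# split the layer stack across nodes (PP). Total GPUs used = 4 * 2 = 8.
engine_kwargs = {
    "tensor_parallel_size": 4,    # GPUs per node
    "pipeline_parallel_size": 2,  # number of nodes
}
```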
Ray Data: Streaming Preprocessing
Ray Data connects your raw data to your training loop with streaming execution. Instead of loading everything into memory first, it processes batches on-the-fly and keeps your GPUs fed.
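A pipeline sketch; the parquet path, column name, and `normalize` transform are illustrative placeholders:

```python
import numpy as np
import ray

# Streaming pipeline sketch: read -> transform lazily, batch by batch.
# The path is a placeholder for your dataset location.
ds = ray.data.read_parquet("s3://my-bucket/train-data/")

def normalize(batch: dict) -> dict:
    # Runs on CPU workers while the GPUs train; nothing is materialized up front.
    batch["image"] = batch["image"].astype(np.float32) / 255.0
    return batch

ds = ds.map_batches(normalize, batch_format="numpy")
```

Hand the dataset to your trainer with `datasets={"train": ds}` on the `TorchTrainer` constructor.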
Inside your training function, access the data through Ray’s iterator:
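A sketch, assuming the dataset was registered under the key "train" and has "image" and "label" columns:

```python
import ray.train

def train_func():
    # Each worker receives its own shard -- no DistributedSampler needed.
    shard = ray.train.get_dataset_shard("train")
    for epoch in range(2):
        # Batches arrive as torch tensors, placed on the worker's device by default.
        for batch in shard.iter_torch_batches(batch_size=128):
            images, labels = batch["image"], batch["label"]
            ...  # forward/backward as usual
```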
This keeps CPU preprocessing and GPU training running in parallel. Ray handles the backpressure automatically, so slow preprocessing doesn’t cause GPU idle time (it buffers ahead), and fast preprocessing doesn’t blow up memory.
Common Errors and Fixes
“No available node types can fulfill resource request”
You asked for more GPUs than your cluster has. Check ray status and compare against your ScalingConfig. If you’re on a single machine with 2 GPUs, num_workers=4 with use_gpu=True will hang.
Environment variables not reaching workers
Variables set on the driver don’t propagate. Pass them explicitly:
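For example, via the job's runtime environment (the variable names and values here are just examples):

```python
import ray

# env_vars in the runtime environment are set on every worker process.
ray.init(runtime_env={"env_vars": {"HF_TOKEN": "hf_xxx", "WANDB_MODE": "offline"}})
```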
Checkpointing fails on multi-node clusters
If you see FileNotFoundError during checkpoint saving, you’re using a local path on a multi-node setup. Switch to S3, GCS, or an NFS mount in your RunConfig.
OOM on workers
Each worker loads a full copy of the model. A 7B model needs ~14GB in fp16 per worker. If you’re running 4 workers on a node with 4x 16GB GPUs, there’s almost no room for activations. Use gradient checkpointing or reduce num_workers.
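The back-of-envelope math for the weights alone:

```python
# fp16 memory for model weights only -- optimizer state, gradients, and
# activations add several times more on top of this during training.
params = 7e9          # 7B parameters
bytes_per_param = 2   # fp16
weights_gb = params * bytes_per_param / 1e9
print(weights_gb)  # 14.0 GB per worker, before anything else
```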
Placement group conflicts
If tasks inside a Ray Train worker launch nested tasks, you’ll see scheduling hangs. Add scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=None) to break out of the parent placement group.
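A sketch of opting a nested task out of the parent's placement group; the task itself is a placeholder:

```python
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# Detach this task from the caller's placement group so it can be scheduled
# anywhere in the cluster instead of competing for the group's reserved slots.
@ray.remote(
    scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=None)
)
def nested_task(x):
    return x * 2
```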
When to Use Ray vs. Raw PyTorch
Use raw torch.distributed when you have a single multi-GPU machine, a simple training loop, and no serving requirements. It’s less abstraction and fewer dependencies.
Use Ray when you need any of these: multi-node training, fault tolerance (Ray restarts failed workers), a unified pipeline from data loading through training to serving, or autoscaling inference deployments. Ray’s overhead on a single machine is minimal, but the real payoff comes when you’re coordinating across nodes or when you want training and serving to share the same cluster.
Related Guides
- How to Set Up Multi-GPU Training with PyTorch
- How to Build a Model Training Checkpoint Pipeline with PyTorch
- How to Build a Multi-Node Training Pipeline with Fabric and NCCL
- How to Build a Model Serving Cluster with Ray Serve and Docker
- How to Speed Up Training with Mixed Precision and PyTorch AMP
- How to Build a Model Inference Cache with Redis and Semantic Hashing
- How to Use PyTorch FlexAttention for Fast LLM Inference
- How to Build a Model Training Queue with Redis and Worker Pools
- How to Speed Up LLM Inference with Speculative Decoding
- How to Build a Model Training Pipeline with Lightning Fabric