PyTorch Distributed Data Parallel step by step



How do you speed up training? How do you train a large model that does not fit into a single GPU's memory? How do you make full use of several GPUs?

Distributed training was created to handle these problems. PyTorch offers two approaches: DataParallel and DistributedDataParallel (DDP).


DataParallel splits each batch into several smaller chunks and feeds one chunk to each GPU; every GPU holds a replica of the model. After the forward pass, the outputs are gathered on the master GPU, where the loss is computed. During the backward pass, the gradients from all replicas are accumulated on the master GPU, which alone updates the parameters and then broadcasts the updated parameters to the other GPUs. There are three key problems with DataParallel:

  • Data moves between GPUs twice per iteration: once for the gradients and once for the model parameters. This leads to a large communication overhead;
  • Usable memory is bounded by the master GPU. Because the gathered outputs, the loss and the parameter update all live on the master GPU, it uses more memory than the others. As a result, you cannot make full use of the other GPUs' memory;
  • Everything is driven by a single Python process with one thread per GPU, so the master GPU and the Python GIL become a bottleneck and training is slow.

Distributed Data Parallel (DDP)

Distributed Data Parallel aims to solve the above problems. It runs one process per GPU, each with its own replica of the model, and registers an autograd hook for every parameter. Each process runs the forward and backward pass locally on its own share of the data; as gradients become ready during the backward pass, the hooks fire and the gradients are averaged across GPUs with the backend's AllReduce operation. Every process then applies the same optimizer step locally, so the replicas stay identical. The only communication cost is the gradient synchronization, and the process does not rely on a master GPU, so all GPUs have similar memory usage. DDP also works across multiple machines, and the underlying torch.distributed package supports point-to-point (P2P) communication as well. For more details, refer to the PyTorch Distributed Overview. Another benefit is that DDP can use multiple CPU cores because it runs several processes, which avoids contention on the Python GIL.
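
To make the gradient synchronization concrete, here is a small pure-Python simulation. It does not use PyTorch or any GPU: the names local_gradient and all_reduce_mean are made up for this sketch, and all_reduce_mean is only a stand-in for the real AllReduce. It assumes a one-parameter model y = w * x and equal-size shards per rank.

```python
# Pure-Python simulation; no PyTorch or GPUs involved.
# Model: y = w * x, loss = mean((w * x - y) ** 2) over the samples a rank sees.

def local_gradient(w, shard):
    """Gradient of the mean squared error w.r.t. w on one rank's shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Stand-in for AllReduce: every rank ends up with the same average."""
    mean = sum(values) / len(values)
    return [mean for _ in values]

data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # generated by y = 2x
world_size = 2
shards = [data[rank::world_size] for rank in range(world_size)]  # equal-size shards

weights = [0.0 for _ in range(world_size)]  # every rank starts from the same w
lr = 0.01
for step in range(200):
    # Each rank computes a gradient on its own shard only.
    grads = [local_gradient(weights[r], shards[r]) for r in range(world_size)]
    # After synchronization every rank holds the same averaged gradient.
    synced = all_reduce_mean(grads)
    # Each rank applies the identical update locally.
    weights = [weights[r] - lr * synced[r] for r in range(world_size)]

# With equal-size shards, the averaged gradient equals the full-batch gradient.
full_batch_grad = local_gradient(0.0, data)
averaged_grad = all_reduce_mean([local_gradient(0.0, s) for s in shards])[0]
```

Since every rank starts from the same weight and applies the same averaged gradient, the replicas never drift apart, which is why no parameter broadcast is needed after each step.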

Using DataParallel takes just a single line of code; refer to the PyTorch documentation for details. Here, I only show how to use DDP on a single machine with multiple GPUs.

Getting started with DDP


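torch.distributed.launch spawns multiple processes for you. nproc_per_node is usually set to the number of GPUs on the node so that each GPU gets one process. Recent PyTorch releases deprecate this launcher in favor of torchrun, which passes the local rank through the LOCAL_RANK environment variable rather than a --local_rank argument.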

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 $args

Prepare data

In supervised learning, you can use DistributedSampler as the sampler of your DataLoader. It splits the data set across processes for you.

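train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)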
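
To see what "splitting the data set" means, here is a rough pure-Python sketch of a strided split without shuffling. The helper shard_indices is hypothetical and written only for illustration; it is meant to mirror the idea behind DistributedSampler (pad the index list so it divides evenly, then give each rank every num_replicas-th index), not to reproduce its code.

```python
import math

def shard_indices(dataset_len, num_replicas, rank):
    """Indices one rank would see, without shuffling (hypothetical helper)."""
    per_rank = math.ceil(dataset_len / num_replicas)
    total = per_rank * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank gets the same number of samples.
    padded = (indices * math.ceil(total / dataset_len))[:total]
    # Strided split: rank r takes positions r, r + num_replicas, ...
    return padded[rank:total:num_replicas]

for rank in range(2):
    print(rank, shard_indices(5, 2, rank))
```

Note that with an uneven data set some samples are seen twice in an epoch because of the padding. When shuffling is enabled on the real sampler, call train_sampler.set_epoch(epoch) at the start of every epoch so the order changes between epochs.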

In reinforcement learning, you can run your environment inside every rank's process, each with a different seed.

DDP initialization with the NVIDIA NCCL backend
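
One simple way to get a different seed per rank is to offset a shared base seed by the rank. The sketch below uses Python's random module as a stand-in for an environment; seed_for_rank and sample_rollout are hypothetical names, and a real setup would probably also seed the environment, NumPy and PyTorch in the same way.

```python
import random

def seed_for_rank(base_seed, rank):
    """Derive a distinct, reproducible seed for each process (hypothetical helper)."""
    return base_seed + rank

def sample_rollout(base_seed, rank, length=3):
    """Stand-in for an environment rollout driven by a per-rank random generator."""
    rng = random.Random(seed_for_rank(base_seed, rank))
    return [rng.random() for _ in range(length)]

# Each rank explores differently, but a rerun with the same base seed repeats itself.
rollout_rank0 = sample_rollout(1234, rank=0)
rollout_rank1 = sample_rollout(1234, rank=1)
```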

import argparse

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torch.distributed.launch passes --local_rank to every process it spawns
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
local_rank = parser.parse_args().local_rank

torch.cuda.set_device(local_rank)  # bind this process to its own GPU
dist.init_process_group(backend='nccl', init_method='env://')
rank = dist.get_rank()
world_size = dist.get_world_size()
print('my rank={} local_rank={}'.format(rank, local_rank))
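
How the local rank reaches your script depends on the launcher. As far as I know, torch.distributed.launch passes it as a command-line argument (spelled --local_rank, or --local-rank in newer releases), while torchrun sets the LOCAL_RANK environment variable. The helper below is a small sketch that accepts all three; get_local_rank is a name I made up, and falling back to 0 when nothing is set is my own assumption.

```python
import argparse
import os

def get_local_rank(argv=None):
    """Return the local rank from the command line or LOCAL_RANK, else 0."""
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    # Accept both spellings of the flag; the value lands in args.local_rank.
    parser.add_argument("--local_rank", "--local-rank", type=int, default=None)
    # parse_known_args ignores the other arguments of the training script.
    args, _unknown = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    # Fall back to the environment variable, then to 0 (assumed default).
    return int(os.environ.get("LOCAL_RANK", 0))
```

Calling get_local_rank() without arguments reads sys.argv, so it can replace the argparse lines above in a script that should work with either launcher.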


Wrap the model with DDP

model = ...  # your network, already moved to GPU local_rank (e.g. with .to(local_rank))
model = DDP(model, device_ids=[local_rank], output_device=local_rank)


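optimizer = optim.SGD(model.parameters(), lr=0.001)  # create once, outside the loop
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)  # reshuffle differently every epoch
    for data, label in train_loader:
        data, label = data.to(local_rank), label.to(local_rank)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_fn(prediction, label)
        loss.backward()  # DDP synchronizes the gradients during backward
        optimizer.step()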

Log data

Use torch.distributed.reduce to sum a value from all ranks onto rank 0, then divide by the world size to get the mean.

loss = loss.clone().detach()
# every rank must call reduce; it works in place and only rank 0 holds the sum
dist.reduce(loss, dst=0, op=dist.ReduceOp.SUM)
if dist.get_rank() == 0:
    # results are collected on rank 0
    loss_mean = loss / dist.get_world_size()
    print(f"epoch: {epoch}, loss: {loss_mean.item()}")

Checkpoint load and save

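When loading, make sure you set map_location properly, so that tensors saved from one GPU are loaded onto the GPU of the current process.

def load_checkpoint_path(model, optimizer, rank, checkpoint_path):
    # remap tensors saved from cuda:0 onto this process's GPU
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint_state = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint_state['model'])
    optimizer.load_state_dict(checkpoint_state['optimizer'])
    iter_init = checkpoint_state['iter_no'] + 1  # next iteration
    return iter_init

if dist.get_rank() == 0:
    # only save on rank 0
    checkpoint_state = {
        'iter_no': iter_no,  # last completed iteration
        'model': model.state_dict(),  # saved from the DDP-wrapped model
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint_state, checkpoint_path)
dist.barrier()  # other ranks wait until the file has been written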
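
One thing to keep in mind: DDP keeps your network under its .module attribute, so if the state dict was saved from the DDP-wrapped model, every key starts with "module.". Loading it back into a DDP-wrapped model works as is, but for a plain, unwrapped model (for example for single-GPU inference) the prefix has to go. The helper below is a minimal sketch on a plain dictionary; strip_module_prefix is a hypothetical name and the example keys are made up.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Made-up keys; the values would be tensors in a real checkpoint.
saved = {"module.fc.weight": 1, "module.fc.bias": 2}
plain = strip_module_prefix(saved)
```

Alternatively, saving model.module.state_dict() instead of model.state_dict() avoids the prefix in the first place.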


To get the same results as on a single card, the batch norm statistics have to be synchronized between GPUs.

For batch norm layers, call convert_sync_batchnorm before wrapping the network with DDP.

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
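
Why does this matter? Without synchronization, each GPU normalizes with the mean and variance of its own share of the batch, which can differ a lot from the statistics of the whole batch. The pure-Python sketch below shows the idea on two made-up shards: combining the per-rank count, sum and sum of squares recovers the full-batch statistics. This is a conceptual illustration only, not how SyncBatchNorm is implemented.

```python
def batch_stats(values):
    """Mean and (biased) variance of one list of activations."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return mean, var

def synced_stats(shards):
    """Combine per-rank count, sum and sum of squares into global statistics."""
    count = sum(len(s) for s in shards)
    total = sum(sum(s) for s in shards)
    total_sq = sum(sum(v * v for v in s) for s in shards)
    mean = total / count
    var = total_sq / count - mean ** 2
    return mean, var

shards = [[1.0, 2.0, 3.0], [10.0, 11.0, 12.0]]  # activations seen by rank 0 and rank 1
local_rank0 = batch_stats(shards[0])             # what rank 0 would use on its own
global_stats = synced_stats(shards)              # what every rank uses after syncing
single_card = batch_stats(shards[0] + shards[1]) # reference: the whole batch on one card
```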

The problems you may face

  • the program hangs when reduce is called on only some of the GPUs;
  • NCCL errors when running inside Docker;
  • "parameter not ready" errors when some parameters are not used to calculate the loss.

I will talk about these later.