Collective Communication in Distributed Systems with PyTorch
The full code for this article is on GitHub
Today, we will explore PyTorch's distributed collective communication features. When working with multiple GPUs, it is necessary to share tensors across them, which is where torch.distributed comes in. It provides a set of APIs to send and receive tensors among multiple workers, where a worker is typically a process responsible for a specific GPU.
When training neural networks on multiple GPUs, it is often necessary to collect tensors across multiple GPUs/processes. A common example is computing gradients across multiple GPUs to perform an optimizer step: we must first share all gradients across all GPUs, average them, and then perform a local optimizer step. This is precisely what the all_reduce function does, but more on that later.
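As a preview, a gradient-averaging step built on all_reduce might look roughly like this minimal sketch (assuming the process group is already initialized and model is a placeholder for your network, not code from this article):
import torch.distributed as dist

def average_gradients(model):
    # assumes dist.init_process_group has already been called
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # sum the gradient over every process, then divide to get the mean
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size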
We will provide figures and code examples for each of the six collective communication strategies in torch.distributed: reduce, all_reduce, scatter, gather, all_gather and broadcast.
Setup
We tested the code with python=3.9 and torch=1.13.1.
Before we look at each collective strategy, we need to set up our multiprocessing code. Following the PyTorch documentation, we first write a function that initializes the distributed environment for each process that will run this file.
import os
from typing import Callable
import torch
import torch.distributed as dist
def init_process(rank: int, size: int, fn: Callable[[int, int], None], backend="gloo"):
    """Initialize the distributed environment."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
The call to dist.init_process_group is where the magic happens. Then, we can use torch.multiprocessing to spawn the processes, in our case 4 of them.
...
import torch.multiprocessing as mp

def func(rank: int, size: int):
    # each process will call this function
    pass

if __name__ == "__main__":
    size = 4
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, func))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
The code creates and runs a number of separate processes using the torch.multiprocessing module. mp.set_start_method("spawn") sets the start method to spawn, which launches a fresh Python interpreter for each child process. Each spawned mp.Process calls init_process, which in turn calls a user-specified function func with rank and size as parameters.
Let's create a simple example,
...
def hello_world(rank: int, size: int):
    print(f"[{rank}] say hi!")
...
for rank in range(size):
    # passing `hello_world`
    p = mp.Process(target=init_process, args=(rank, size, hello_world))
...
This will output:
[0] say hi!
[2] say hi!
[3] say hi!
[1] say hi!
Reduce
The reduce operation in torch.distributed is used to combine tensors from multiple GPUs or processes into a single tensor on one of them. The reduce operation applies a specified reduction operation (e.g. sum, product, max) element-wise to the input tensors and returns the result on a single GPU or process, known as the root rank. The root rank is specified as an argument when calling the reduce function.
In the figure, we have 4 processes and we reduce all the tensors to rank 0. Each process starts with a tensor containing 1.
def do_reduce(rank: int, size: int):
    # create a group with all processors
    group = dist.new_group(list(range(size)))
    tensor = torch.ones(1)
    # send all tensors to rank 0 and sum them
    dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM, group=group)
    # op can also be dist.ReduceOp.PRODUCT, dist.ReduceOp.MAX, dist.ReduceOp.MIN
    # only rank 0 will have the final sum (4.0)
    print(f"[{rank}] data = {tensor[0]}")
outputs:
[1] data = 3.0
[2] data = 2.0
[0] data = 4.0
[3] data = 1.0
Only rank 0 has the reduced value (4.0); the values printed on the other ranks are intermediate partial results, since only the destination rank is guaranteed to hold the final result.
The op argument to dist.reduce specifies the reduction operation, which in this case is set to torch.distributed.ReduceOp.SUM; other possibilities are PRODUCT, MAX and MIN.
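As a small variation (not in the original code, but following the same structure), swapping op is enough to compute, say, the maximum across all ranks:
def do_reduce_max(rank: int, size: int):
    # same setup as do_reduce, but each rank contributes its own rank number
    group = dist.new_group(list(range(size)))
    tensor = torch.tensor([rank], dtype=torch.float32)
    # with 4 processes, rank 0 ends up with the maximum, 3.0
    dist.reduce(tensor, dst=0, op=dist.ReduceOp.MAX, group=group)
    print(f"[{rank}] data = {tensor[0]}")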
It's important to note that the reduce operation happens across tensors that live on different devices, and the root rank needs a tensor already allocated in its local memory (in our case tensor) to receive the result.
All Reduce
The all_reduce operation in torch.distributed is similar to the reduce operation, but instead of returning the result on a single GPU or process, it returns the result on all GPUs or processes.
Like the reduce operation, all_reduce applies a specified reduction operation (e.g. sum, product, max) element-wise to the input tensors and returns the result on all GPUs or processes. This allows for easy computation of values such as the mean or sum of a tensor across all GPUs or processes.
The code,
def do_all_reduce(rank: int, size: int):
    # create a group with all processors
    group = dist.new_group(list(range(size)))
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    # op can also be dist.ReduceOp.PRODUCT, dist.ReduceOp.MAX, dist.ReduceOp.MIN
    # will output 4 for all ranks
    print(f"[{rank}] data = {tensor[0]}")
outputs:
[3] data = 4.0
[0] data = 4.0
[1] data = 4.0
[2] data = 4.0
Each process has the reduced value (4.0).
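A common use of this is to turn the sum into an average, for example to average a metric or a loss across workers. A minimal sketch, reusing the setup above (the loss values are made up for illustration):
def do_all_reduce_mean(rank: int, size: int):
    group = dist.new_group(list(range(size)))
    # pretend each rank computed a different local loss
    loss = torch.tensor([float(rank)])
    # sum the losses from every rank, then divide by the number of ranks
    dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=group)
    loss /= size
    # every rank now sees the same mean: (0 + 1 + 2 + 3) / 4 = 1.5
    print(f"[{rank}] mean loss = {loss[0]}")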
Scatter
The scatter operation in torch.distributed is used to divide a tensor on one GPU or process, known as the root rank, and send one chunk of it to each GPU or process. The root rank is specified as an argument when calling the scatter function.
The code,
def do_scatter(rank: int, size: int):
    # create a group with all processors
    group = dist.new_group(list(range(size)))
    tensor = torch.empty(1)
    # rank 0 sends one chunk to every rank
    if rank == 0:
        tensor_list = [torch.tensor([i + 1], dtype=torch.float32) for i in range(size)]
        # tensor_list = [tensor(1), tensor(2), tensor(3), tensor(4)]
        dist.scatter(tensor, scatter_list=tensor_list, src=0, group=group)
    else:
        dist.scatter(tensor, scatter_list=[], src=0, group=group)
    # each rank will have a tensor containing rank + 1
    print(f"[{rank}] data = {tensor[0]}")
outputs:
[2] data = 3.0
[1] data = 2.0
[3] data = 4.0
[0] data = 1.0
The scatter function is useful when you want to split a large tensor into smaller chunks and distribute them across multiple GPUs or processes. For example, you can use it to split a large batch of data into smaller chunks for parallel processing.
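Here is a minimal sketch of that idea (hypothetical, assuming 4 processes as above): rank 0 holds a batch of 8 samples and hands a chunk of 2 samples to each process.
def do_scatter_batch(rank: int, size: int):
    group = dist.new_group(list(range(size)))
    # each process will receive a chunk of 2 samples with 4 features
    chunk = torch.empty(2, 4)
    if rank == 0:
        # a full batch of 8 samples, split into `size` equal chunks
        batch = torch.randn(8, 4)
        dist.scatter(chunk, scatter_list=list(batch.chunk(size, dim=0)), src=0, group=group)
    else:
        dist.scatter(chunk, scatter_list=[], src=0, group=group)
    print(f"[{rank}] chunk shape = {tuple(chunk.shape)}")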
Gather
The gather operation in torch.distributed is used to collect tensors from multiple GPUs or processes and concatenate them into a single list of tensors on one of them, known as the root rank. The root rank is specified as an argument when calling the gather function.
The code,
def do_gather(rank: int, size: int):
    # create a group with all processors
    group = dist.new_group(list(range(size)))
    tensor = torch.tensor([rank], dtype=torch.float32)
    # collecting the tensors from all ranks on rank 0
    if rank == 0:
        # create an empty list we will use to hold the gathered values
        tensor_list = [torch.empty(1) for i in range(size)]
        dist.gather(tensor, gather_list=tensor_list, dst=0, group=group)
    else:
        dist.gather(tensor, gather_list=[], dst=0, group=group)
    # only rank 0 will have the tensors from the other processes
    # [tensor([0.]), tensor([1.]), tensor([2.]), tensor([3.])]
    if rank == 0:
        print(f"[{rank}] data = {tensor_list}")
outputs:
[0] data = [tensor([0.]), tensor([1.]), tensor([2.]), tensor([3.])]
rank 0 has the gathered tensors.
Also, the gather operation requires that the tensors on all the devices have the same shape, and the root rank needs to allocate the output tensors in its local memory (here tensor_list) to receive the result.
It's useful when you want to have all the information on one device or process to perform some kind of operation or analysis on it.
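For example, one might gather each worker's local loss on rank 0 for logging. A small sketch in the same style (the loss values are invented for the example):
def do_gather_losses(rank: int, size: int):
    group = dist.new_group(list(range(size)))
    # pretend this is the loss computed locally on this rank
    local_loss = torch.tensor([0.1 * rank], dtype=torch.float32)
    if rank == 0:
        # rank 0 allocates one slot per rank to receive the losses
        losses = [torch.empty(1) for _ in range(size)]
        dist.gather(local_loss, gather_list=losses, dst=0, group=group)
        print(f"[{rank}] losses = {losses}")
    else:
        dist.gather(local_loss, gather_list=[], dst=0, group=group)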
All Gather
The all_gather operation in torch.distributed is similar to the gather operation, but instead of returning the gathered tensors on a single GPU or process, it returns them on all GPUs or processes.
Like the gather operation, all_gather collects tensors from multiple GPUs or processes and concatenates them into a single list of tensors on every GPU or process. This allows for easy collection of information from all GPUs or processes.
The code,
def do_all_gather(rank: int, size: int):
    # create a group with all processors
    group = dist.new_group(list(range(size)))
    tensor = torch.tensor([rank], dtype=torch.float32)
    # create an empty list we will use to hold the gathered values
    tensor_list = [torch.empty(1) for i in range(size)]
    # every rank sends its tensor to all the others
    dist.all_gather(tensor_list, tensor, group=group)
    # all ranks will have [tensor([0.]), tensor([1.]), tensor([2.]), tensor([3.])]
    print(f"[{rank}] data = {tensor_list}")
outputs,
[0] data = [tensor([0.]), tensor([1.]), tensor([2.]), tensor([3.])]
[2] data = [tensor([0.]), tensor([1.]), tensor([2.]), tensor([3.])]
[3] data = [tensor([0.]), tensor([1.]), tensor([2.]), tensor([3.])]
[1] data = [tensor([0.]), tensor([1.]), tensor([2.]), tensor([3.])]
Each process has the gathered data. The all_gather operation requires that the tensors on all the devices have the same shape.
It's useful when you want to have a copy of the information on all the devices or processes, maybe to synchronize something between workers.
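For example, every rank could share how many samples it processed so that each one can compute the same global total locally. A rough sketch (the counts are made up for illustration):
def do_all_gather_counts(rank: int, size: int):
    group = dist.new_group(list(range(size)))
    # pretend each rank processed a different number of samples
    local_count = torch.tensor([10.0 * (rank + 1)])
    counts = [torch.empty(1) for _ in range(size)]
    dist.all_gather(counts, local_count, group=group)
    # every rank computes the same total: 10 + 20 + 30 + 40 = 100
    total = torch.stack(counts).sum()
    print(f"[{rank}] total samples = {total.item()}")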
Broadcast
The broadcast operation in torch.distributed is used to send a tensor from one GPU or process, known as the root rank, to all other GPUs or processes. The root rank is specified as an argument when calling the broadcast function.
The code,
def do_broadcast(rank: int, size: int):
    # create a group with all processors
    group = dist.new_group(list(range(size)))
    if rank == 0:
        tensor = torch.tensor([rank], dtype=torch.float32)
    else:
        tensor = torch.empty(1)
    # rank 0 sends its tensor to all the other ranks
    dist.broadcast(tensor, src=0, group=group)
    # all ranks will have tensor([0.]) from rank 0
    print(f"[{rank}] data = {tensor}")
outputs,
[2] data = tensor([0.])
[1] data = tensor([0.])
[3] data = tensor([0.])
[0] data = tensor([0.])
Each process has rank 0's tensor. The broadcast function is useful when you want to share the same information or values across multiple GPUs or processes. For example, you can use it to share the model's parameters or shared optimizer state across multiple GPUs or processes.
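A sketch of that pattern (not from the original code), assuming every process builds the same model architecture locally, so that rank 0's initial weights can be copied to everyone before training starts:
import torch.nn as nn

def sync_initial_parameters(model: nn.Module):
    # copy rank 0's parameters into every other process's model
    for param in model.parameters():
        dist.broadcast(param.data, src=0)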
Conclusion
In conclusion, PyTorch's distributed collective communication feature provides a powerful set of tools for working with multiple GPUs. The six collective strategies we have discussed - reduce, all_reduce, scatter, gather, all_gather and broadcast - offer a range of options for sharing tensors across workers. Each strategy has its own use case and can be tailored to the specific needs of your project.
By mastering these strategies, you can effectively train neural networks on multiple GPUs, improving performance and scalability. With this knowledge, you can take full advantage of the power of distributed computing and unlock the full potential of your models.
Thanks for reading :)