In the world of deep learning training, the role of the ML developer can be likened to that of the conductor of an orchestra. Just as a conductor must time the entry of each instrument to produce the perfect harmony, so must the ML practitioner orchestrate a multitude of hardware components — CPUs and GPUs with their associated memory, high-speed storage, network controllers, various communication buses, etc. — to work together seamlessly to maximize runtime performance. And just as a single off-key note can disrupt an entire musical production, a bottleneck or inefficiency in any one of these components can severely hamper the overall training process.

In this complex landscape, it’s of critical importance that you have an intimate understanding of your system’s underlying topology and that you know how to apply it toward optimal runtime performance. In a previous post, we explored the critical role of topology awareness in a distributed training setting and discussed the advantage of topology-aware gradient sharing algorithms in minimizing cross-node communication and boosting performance.

In this post, the tenth in our series on PyTorch model analysis and optimization, we zoom in on the collaboration between the CPU and GPU in training and running AI/ML models. In a typical training pipeline, the CPU is responsible for preparing and pre-processing data, launching GPU kernels, and processing output, while the GPU is responsible for the model execution. This cooperation isn’t merely a hand-off — it’s a constant, high-speed exchange of data and commands, an intricate dance in which precise timing and physical proximity are crucial. For this dance to be performed optimally, it must be choreographed in a manner that accounts for the underlying system topology. In particular, it must take into account the system’s Non-Uniform Memory Access (NUMA) architecture.

NUMA Architecture

The NUMA architecture is designed to optimize memory transactions by associating local memory banks directly with specific CPU sockets. Most modern multi-GPU High-Performance Computing (HPC) systems consist of two or more NUMA nodes, where CPUs and GPUs are divided into disjoint groups, each attached to one node. NUMA is most efficient when memory banks are accessed from within the same node. Accessing memory on a remote node requires data traversal over a dedicated NUMA interconnect, which is significantly slower than accessing local memory. In memory-intensive applications like AI/ML workloads, cross-NUMA memory accesses can introduce performance bottlenecks.

Unfortunately, popular AI/ML frameworks — most notably PyTorch — do not account for NUMA architecture by default. However, as we will demonstrate in this post, you can introduce NUMA-awareness into your PyTorch script without much difficulty.

In the next section, we will explore the NUMA architecture of the popular Amazon EC2 p4d.96xlarge instance (containing 8 NVIDIA A100 GPUs and 96 vCPUs) running a PyTorch (2.6) Deep Learning AMI (DLAMI). We will then demonstrate how to implement a NUMA-aware PyTorch script and evaluate its impact on runtime performance.

Disclaimers

The NUMA architecture is a complex and nuanced topic. In this post, we explore just one of its implications: its impact on deep learning. For more comprehensive details on the topic, please refer to other authoritative resources.

The code we will share is intended for demonstrative purposes and should not be relied on for correctness or optimality. Please do not interpret our choice of platform, framework, or any other tool or library as an endorsement for its use.

NUMA Architecture Discovery

There are multiple ways to detect the NUMA architecture of the system you are running on. In this section, we will demonstrate how to explore the NUMA layout of an Amazon EC2 p4d.96xlarge instance using commonly available Linux command-line tools.

CPU NUMA Node Discovery

The lscpu command provides information about the CPU architecture of a Linux system, including a section describing the NUMA layout. Running the command on an Amazon EC2 p4d.96xlarge instance shows that it consists of 96 vCPUs divided between two NUMA nodes:

NUMA:                     
  NUMA node(s):           2
  NUMA node0 CPU(s):      0-23,48-71
  NUMA node1 CPU(s):      24-47,72-95

GPU NUMA Node Discovery

To determine which NUMA node each GPU is attached to, we use a two-step process: First, we identify the PCI ID associated with each GPU, and then we look up the NUMA node associated with that PCI ID.

The PCI ID is one of the properties of the GPUs reported by the nvidia-smi utility. In the following snippet, we see the PCI Bus IDs of the first two out of the eight GPUs on our Amazon EC2 p4d.96xlarge instance:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.20             Driver Version: 570.133.20     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:10:1C.0 Off |                    0 |
| N/A   48C    P0             57W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  |   00000000:10:1D.0 Off |                    0 |
| N/A   45C    P0             56W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

Next, we use these PCI IDs to determine the corresponding NUMA node by reading from the /sys/bus/pci/devices/ path:

ubuntu@XX:~$ cat /sys/bus/pci/devices/0000:10:1c.0/numa_node
0
ubuntu@XX:~$ cat /sys/bus/pci/devices/0000:10:1d.0/numa_node
0

This indicates that GPUs 0 and 1 are connected to NUMA node 0.
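
If you would like to automate this two-step lookup, here is a minimal sketch of our own (the helper name gpu_numa_nodes is ours, not part of any standard tooling) that combines the machine-readable output of nvidia-smi --query-gpu with the corresponding sysfs entries:

import subprocess

# Sketch: map each GPU index to its NUMA node by combining
# `nvidia-smi --query-gpu` output with the sysfs numa_node entry.
def gpu_numa_nodes():
    out = subprocess.run(
        ['nvidia-smi', '--query-gpu=index,pci.bus_id', '--format=csv,noheader'],
        check=True, capture_output=True, text=True
    ).stdout
    mapping = {}
    for line in out.strip().splitlines():
        idx, bus_id = [s.strip() for s in line.split(',')]
        # nvidia-smi reports e.g. 00000000:10:1C.0; sysfs expects 0000:10:1c.0
        sysfs_id = bus_id[4:].lower()
        with open(f'/sys/bus/pci/devices/{sysfs_id}/numa_node') as f:
            mapping[int(idx)] = int(f.read().strip())
    return mapping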

Additional Tools

An alternative method for finding the NUMA node assignment of the PCI IDs is to use lstopo — a command-line utility that reports the topology of a computer system. Though it isn’t included in the DLAMI by default, it can be easily installed by running:

sudo apt install hwloc

Here is a small segment of its command-line output, which reports four PCI IDs on NUMA node 0. These are marked with “(3D)” tags — common identifiers of 3D accelerators, otherwise known as GPUs.

Machine (1122GB total)
  Package L#0
    NUMANode L#0 (P#0 561GB)
    HostBridge
      2 x { PCI 10:1c.0-1d.0 (3D) }
    HostBridge
      2 x { PCI 20:1c.0-1d.0 (3D) }

Another useful tool is numactl — a Linux command-line utility for inspecting and managing NUMA policies. To install numactl, run:

sudo apt install numactl

You can inspect the NUMA configuration by running:

numactl --hardware

On our Amazon EC2 p4d.96xlarge instance this produces the following output:

available: 2 nodes (0-1)
node 0 cpus: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
node 0 size: 574309 MB
node 0 free: 572012 MB
node 1 cpus: 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
node 1 size: 574411 MB
node 1 free: 572420 MB
node distances:
node   0   1 
  0:  10  21 
  1:  21  10

This provides useful information such as memory sizes and CPU assignments per NUMA node, as well as inter-node memory access costs (higher numbers = greater latency).

NUMA Topology Summary

To summarize the topology we’ve discovered, here is a Python representation of the CPU and GPU layout:

cpus_per_numa_node = [
    list(range(0, 24)) + list(range(48, 72)), # NUMA node 0
    list(range(24, 48)) + list(range(72, 96)) # NUMA node 1
]

gpus_per_numa_node = [
    [0, 1, 2, 3], # NUMA node 0
    [4, 5, 6, 7]  # NUMA node 1
]

We will use this later to implement NUMA-aware training.

The Impact of NUMA Placement on Data Loading

Memory transactions between CPU and GPU occur at various stages during model execution — for example, when offloading tensors to CPU memory, or when executing certain model components (e.g., sequential algorithms such as non-maximum suppression) on the CPU. In this post, we’ll focus on the transfer of input data from the CPU to the GPU — a critical part of every AI/ML workflow.

The CPU Processes in a Typical Distributed Training Job

In a typical distributed training setting, new CPU processes are created on two occasions:

  • At startup: A separate training process is created for each GPU. These processes handle model setup and training execution on their assigned GPUs. In the script we’ll introduce later, these are launched via torch.multiprocessing.spawn.
  • Per dataloader: Each training process creates its own DataLoader instance to provide data batches for its GPU. Each dataloader typically creates multiple worker processes, which generate individual training samples. These samples are then grouped by the main process into batches.

In the case of our Amazon EC2 p4d.96xlarge instance, each of these processes is assigned to a CPU, which resides on one of the two NUMA nodes.

Why NUMA Placement Matters

Ideally, the main training process for a given GPU — and all of its associated dataloader worker processes — will be located on the same NUMA node as the GPU. Otherwise, we may end up seeing a considerable amount of traffic on the NUMA interconnects, which could result in performance bottlenecks.

Let’s imagine a particularly bad setup:

  • GPU i is located on NUMA node 0.
  • The main training process assigned to GPU i is scheduled on a CPU on NUMA node 1.
  • The worker processes spawned by the training process are all assigned to CPUs on NUMA node 0.

This results in the following inefficient sequence:

  1. Individual samples are created on NUMA node 0.
  2. The samples are transmitted through the interconnect to the main process on node 1, where they are grouped together into a training batch.
  3. The batch is sent back across the interconnect to node 0, where it is fed to the GPU.

Sounds horrendous, right?

While this exact scenario may be rare, it illustrates how the default Linux scheduler — if left unmanaged — can result in inefficient placement and redundant traffic over the NUMA interconnect. And with the high cost of GPU training, relying on the “luck of the scheduler” is not recommended.

When NUMA Placement Matters Most

The performance impact of poor NUMA placement depends heavily on the workload characteristics. Specifically, training steps that involve many large memory transactions will suffer more than steps with only a few small transfers.

When it comes to dataloading, the impact of inefficient NUMA placement will also depend on the size of the model. Recall that AI/ML workloads are designed to run dataloading on the CPU in parallel with model execution on the GPU. Thus, if the GPU execution takes significantly longer than the dataloading, inefficient NUMA placement might go unnoticed. But if dataloading time is similar to or longer than GPU execution time — or if you’re already experiencing GPU starvation — the impact can be significant.

Benchmark Impact of NUMA Pinning

Because the effect of NUMA-aware pinning can vary widely, it’s essential to benchmark its impact on a per-workload basis.

In some situations, NUMA pinning could even hurt performance. For instance, in systems where the CPUs on one NUMA node are designated for other tasks, or where one NUMA node contains CPUs but no GPUs, NUMA pinning could limit access to CPU resources and ultimately reduce throughput.

A Toy PyTorch Experiment

To demonstrate the impact of NUMA awareness on runtime performance, we design a toy distributed training experiment. Our baseline implementation simply reports the NUMA assignment of each spawned process. We then apply NUMA-based CPU and memory affinity and measure the impact on throughput.

NUMA Discovery and Pinning Utilities

We begin by defining utility functions for NUMA node discovery and pinning. The implementation shown here uses the hardcoded NUMA topology we summarized earlier. A more robust version would dynamically discover topology by parsing the output of system utilities such as lscpu and nvidia-smi.
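
As an illustration of what such dynamic discovery might look like on the CPU side, here is a minimal sketch (our own helper, not part of the script below) that builds the cpus_per_numa_node list from the machine-readable output of lscpu --parse; it could be paired with the nvidia-smi-based GPU lookup shown earlier:

import subprocess

# Sketch: build cpus_per_numa_node dynamically from `lscpu --parse=CPU,NODE`,
# whose output is one "cpu,node" pair per line (comment lines start with '#').
def discover_cpus_per_numa_node():
    out = subprocess.run(['lscpu', '--parse=CPU,NODE'],
                         check=True, capture_output=True, text=True).stdout
    nodes = {}
    for line in out.splitlines():
        if not line or line.startswith('#'):
            continue
        cpu, node = (int(x) for x in line.split(','))
        nodes.setdefault(node, []).append(cpu)
    return [sorted(nodes[n]) for n in sorted(nodes)]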

The following code block contains utilities for looking up NUMA placement. For each process, we report both the NUMA node of the CPU it is running on and the NUMA node(s) its memory is bound to. We use numactl --show to detect the memory binding of the current process.

import os, re, psutil, ctypes, subprocess

# Discover NUMA node of process
def discover_cpu_numa_placement():
    cpu_id = psutil.Process().cpu_num()
    for node in range(len(cpus_per_numa_node)):
        if cpu_id in cpus_per_numa_node[node]:
            return node


# Discover NUMA node of GPU
def discover_gpu_numa_placement(rank):
    for node in range(len(gpus_per_numa_node)):
        if rank in gpus_per_numa_node[node]:
            return node


# Use numactl to get memory binding of CPU process
def get_membinding():
    result = subprocess.run(['numactl', '--show'],
                            check=True,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            text=True)
    output = result.stdout
    match = re.search(r"membind:\s*([0-9\s]+)", output)
    nodes = [int(n) for n in match.group(1).split()]
    return nodes

# Detect NUMA placement of process
def get_numa_placement(rank):
    cpu_node = discover_cpu_numa_placement()
    gpu_node = discover_gpu_numa_placement(rank)
    m_bind = get_membinding()
    node_match = cpu_node == gpu_node
    status = f"GPU node: {gpu_node}\n" \
             f"CPU node: {cpu_node}\n" \
             f"mem binding {m_bind[0] if len(m_bind)==1 else m_bind}\n"
    if not node_match:
        status += "GPU and CPU NUMA nodes do NOT match\n"
    return status

One common method for setting CPU affinity in Python is the os.sched_setaffinity function. However, this method is insufficient for our purposes because it only pins the process to specific CPUs — it does not bind its memory allocations. To bind both the CPU and the memory, we use the numa_bind function from the libnuma library (run sudo apt install libnuma-dev to install it).

# Set process affinity by NUMA node ID
def set_affinity_by_node(node):
    pid = os.getpid()
    target_cpus = cpus_per_numa_node[node]
    os.sched_setaffinity(pid, target_cpus)


# Bind a process and memory to given NUMA node
def numa_bind(node):
    libnuma = ctypes.CDLL("libnuma.so")
    libnuma.numa_allocate_nodemask.restype = ctypes.c_void_p
    libnuma.numa_bitmask_clearall.argtypes = [ctypes.c_void_p]
    libnuma.numa_bitmask_setbit.argtypes = [ctypes.c_void_p, ctypes.c_uint]
    libnuma.numa_bind.argtypes = [ctypes.c_void_p]

    nodemask_ptr = libnuma.numa_allocate_nodemask()
    libnuma.numa_bitmask_clearall(nodemask_ptr)
    libnuma.numa_bitmask_setbit(nodemask_ptr, node)
    libnuma.numa_bind(nodemask_ptr)

Model Definition

Next, we define a simple distributed training script using a ResNet-18 image classification model and a synthetic dataset. Each synthetic sample is a randomly generated 1024×1024 image, simulating large memory transactions. On the GPU, images are downscaled to 224×224 before being passed to the model. This setup results in a bottleneck in the input data pipeline. The bottleneck can be detected by comparing throughput (in steps per second) during normal training versus when running on a cached batch. For more on identifying dataloader bottlenecks, see our earlier posts (e.g., here and here).
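
For illustration, such a cached-batch measurement might look something like the following sketch. It is not part of the experiment script; it assumes the model, data_loader, transform, criterion, and optimizer objects defined in the script below, along with the torch and time imports:

# Sketch: estimate GPU-only throughput by repeatedly training on a single
# cached batch, bypassing the dataloader. Comparing this number with the
# throughput measured on real data exposes an input-pipeline bottleneck.
def cached_batch_throughput(model, data_loader, transform, criterion,
                            optimizer, device, steps=100):
    inputs, targets = next(iter(data_loader))
    inputs, targets = inputs.to(device), targets.to(device)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(transform(inputs)), targets)
        loss.backward()
        optimizer.step()
    torch.cuda.synchronize()
    return steps / (time.perf_counter() - t0)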

Each time a new process is started, it reports its NUMA assignment using the utilities we defined above. For the dataloader workers this is done using a custom worker_init_fn function. We include a numa_aware control flag that determines whether to apply NUMA pinning.

It’s important to note that when NUMA binding is applied with numa_bind inside a running process, the binding is not always inherited by its subprocesses. It is therefore essential to reapply the NUMA binding explicitly within the dataloader workers.

import time
import torch
from functools import partial
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from torchvision.transforms import Resize


# A synthetic dataset with random images and labels
class FakeDataset(Dataset):
    def __init__(self, n_items):
        super().__init__()
        self.n_items = n_items

    def __len__(self):
        return self.n_items

    def __getitem__(self, index):
        rand_image = torch.randn([3, 1024, 1024], dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label


# Callback for DataLoader workers to detect their NUMA placement.
def worker_init_fn(worker_id, rank=0, bind_to_node=None):
    if bind_to_node is not None:
        numa_bind(bind_to_node)
    print(f'GPU {rank} worker {worker_id} NUMA properties:\n'
          f'{get_numa_placement(rank)}')

# standard training loop
def train(
        local_rank,
        world_size,
        numa_aware=False
):
    bind_to_node = None
    if numa_aware:
        bind_to_node = discover_gpu_numa_placement(local_rank)
        numa_bind(bind_to_node)

    print(f'GPU {local_rank} training process NUMA properties:\n'
          f'{get_numa_placement(local_rank)}')

    torch.cuda.set_device(local_rank)

    # DDP setup
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(2222)
    dist.init_process_group('nccl', rank=local_rank,
                            world_size=world_size)

    device = torch.cuda.current_device()
    model = DDP(resnet18().to(device), [local_rank])
    transform = Resize(224)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters())

    # num steps
    warmup = 10
    active = 100
    total_steps = warmup + active

    # distribute evenly across GPUs
    num_workers = os.cpu_count() // world_size
    batch_size = 128
    data_loader = DataLoader(
        FakeDataset(total_steps * batch_size),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        worker_init_fn=partial(
            worker_init_fn,
            rank=local_rank,
            bind_to_node=bind_to_node
        )
    )

    for idx, (inputs, target) in enumerate(data_loader, start=1):
        inputs = inputs.to(device, non_blocking=True)
        targets = target.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(transform(inputs))
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if idx == warmup:
            torch.cuda.synchronize()
            t0 = time.perf_counter()
        elif idx == total_steps:
            break

    if local_rank == 0:
        torch.cuda.synchronize()
        total_time = time.perf_counter() - t0
        print(f'average step time: {total_time / active}')
        print(f'average throughput: {active / total_time}')

    dist.destroy_process_group()


if __name__ == '__main__':
    bind2gpu = False

    if os.environ.get("LOCAL_RANK", None):
        # initialized with torchrun or bash script
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        train(local_rank, world_size, bind2gpu)
    else:
        world_size = torch.cuda.device_count()
        torch.multiprocessing.spawn(
            fn=train,
            args=(world_size, bind2gpu),
            nprocs=world_size,
            join=True
        )

Observing NUMA Placement

Here is a sample output from running the script on a single GPU with four dataloader workers and no NUMA binding. In this run, all processes were scheduled on NUMA node 1, while the GPU resides on NUMA node 0:

GPU 0 training process NUMA properties:
GPU node: 0
CPU node: 1
mem binding [0, 1]
GPU and CPU NUMA nodes do NOT match

GPU 0 worker 1 NUMA properties:
GPU node: 0
CPU node: 1
mem binding [0, 1]
GPU and CPU NUMA nodes do NOT match

GPU 0 worker 3 NUMA properties:
GPU node: 0
CPU node: 1
mem binding [0, 1]
GPU and CPU NUMA nodes do NOT match

GPU 0 worker 0 NUMA properties:
GPU node: 0
CPU node: 1
mem binding [0, 1]
GPU and CPU NUMA nodes do NOT match

GPU 0 worker 2 NUMA properties:
GPU node: 0
CPU node: 1
mem binding [0, 1]
GPU and CPU NUMA nodes do NOT match

Baseline Results

NUMA placement can vary between runs, so we repeated the baseline experiment ten times. The resultant average throughput was 1.04 steps per second.

NUMA-Aware Training

To enable NUMA-aware training, we set the numa_aware flag to True (in our script, via the bind2gpu variable in the main block). This causes each training process to run on a CPU from the same NUMA node as its assigned GPU and to allocate memory on that same node. This configuration ensures NUMA locality across CPU, memory, and GPU, reducing traffic over the NUMA interconnect.

The average throughput in this setting increased to 1.24 steps per second — a 19% improvement over the baseline experiment.

CPU Binding with numactl

An alternative approach to NUMA pinning is to launch each training process from the command line via the numactl command. The advantage of this method is that the binding is applied before the process starts rather than on entry, which avoids the possibility of early memory allocations landing on the wrong node before pinning. Another advantage is that the NUMA placement is inherited by subprocesses, making it unnecessary to re-pin the dataloader workers manually. Note that the inheritance behavior may vary between systems, so you should confirm it on your specific setup before relying on it.

One downside of this method is that it cannot be easily integrated with PyTorch’s launch utilities like torch.multiprocessing.spawn or torchrun. If your code depends on those utilities, you may need to replicate some of their logic manually. Furthermore, some high-level frameworks (e.g., Lightning) may not expose control over process initialization, preventing the use of binding via numactl.

Here’s a sample Bash script that wraps our training script with NUMA pinning using numactl:

#!/bin/bash

# Define GPU-to-NUMA mapping
GPU_LIST=(0 1 2 3 4 5 6 7)
GPU_TO_NUMA=(0 0 0 0 1 1 1 1)

NUM_GPUS=${#GPU_LIST[@]}
WORLD_SIZE=$NUM_GPUS

for i in "${!GPU_LIST[@]}"; do
    GPU_ID=${GPU_LIST[$i]}
    NUMA_NODE=${GPU_TO_NUMA[$i]}
    LOCAL_RANK=$i

    echo "Launch GPU $LOCAL_RANK on NUMA $NUMA_NODE" >&1

    numactl --cpunodebind=$NUMA_NODE --membind=$NUMA_NODE \
        env \
            LOCAL_RANK=$LOCAL_RANK \
            WORLD_SIZE=$WORLD_SIZE \
        python train.py &

done

wait

Results

The table below summarizes the results of our experiments.

[Table: Experiment Results (by Author)]

In this toy example, the benefits of NUMA-aware training are clear. However, as noted earlier, the actual impact can vary depending on your model architecture, data loading characteristics, and system configuration.

Summary

In our constant pursuit of AI/ML workload optimization, topology awareness — including NUMA node placement — is critical.

In this post, we continued our exploration of PyTorch model profiling and optimization by demonstrating how NUMA pinning can improve throughput performance. We hope you will find this method useful in your own AI/ML projects.

For more tips, tricks, and techniques for optimizing PyTorch model development, be sure to check out the other posts in this series.
