Performance bottlenecks in the data input pipeline of a machine learning model running on a GPU can be particularly frustrating. In most workloads, the host (CPU) and the device (GPU) work in tandem: the CPU is responsible for preparing and feeding data, while the GPU handles the heavy lifting — executing the model, performing backpropagation during training, and updating weights.

In an ideal situation, we want the GPU — the most expensive component of our AI/ML infrastructure — to be highly utilized. This leads to faster development cycles, lower training costs, and reduced latency in deployment. To achieve this, the GPU must be continuously fed with input data. In particular, we would like to prevent the onset of “GPU starvation” — a situation in which our most expensive resource lies idle while it waits for input data. Unfortunately, “GPU starvation” due to bottlenecks in the data input pipeline is quite common and can dramatically reduce system efficiency. As such, it’s important for AI/ML developers to have reliable tools and strategies for diagnosing and addressing such issues.

This post — the eighth in our series on the topic of PyTorch Model Performance Analysis and Optimization — introduces a simple caching strategy for identifying bottlenecks in the data input pipeline. As in earlier posts, we aim to reinforce two key ideas:

  1. AI/ML developers must take responsibility for the runtime performance of their models.
  2. You do not need to be a CUDA or systems expert to implement significant performance optimizations.

We’ll start by outlining some of the common causes of GPU starvation. Then we’ll introduce our caching-based strategy for identifying and analyzing input pipeline performance issues. We’ll close by reviewing a set of practical tips, tricks, and techniques (TTTs) for overcoming performance bottlenecks in the data input pipeline.

To facilitate our discussion, we will define a toy PyTorch model and an associated data input pipeline. The code that we will share is intended for demonstrative purposes — please do not rely on its correctness or optimality. Furthermore, please do not interpret our mention of any tool or technique as an endorsement of its use.

A Toy PyTorch Model

We define a simple PyTorch-based image classification model.

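The exact architecture is unimportant for our discussion. The minimal sketch below, along with placeholder values for the global settings (input_img_size, img_size, and num_classes) used by the data pipeline code that follows, will serve as our stand-in. Treat the specific layers and values as illustrative assumptions:

import torch
import torch.nn as nn

# assumed global settings, consistent with the 1024x1024 synthetic images
# and the random crop described below
input_img_size = (1024, 1024)
img_size = 256
num_classes = 10

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # a small stack of strided convolutions followed by a linear classifier
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.flatten(1))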

We define a synthetic dataset with a number of transformations — intentionally designed to include a severe input pipeline bottleneck. For more details on the dataset definition please see this post.

import numpy as np
import torch
from PIL import Image
from torchvision.datasets.vision import VisionDataset
import torchvision.transforms as T

class FakeDataset(VisionDataset):
    def __init__(self, transform):
        super().__init__(root=None, transform=transform)
        self.size = 10000

    def __getitem__(self, index):
        # create a random 1024x1024 image
        img = Image.fromarray(np.random.randint(
            low=0,
            high=256,
            size=(input_img_size[0], input_img_size[1], 3),
            dtype=np.uint8
        ))
        # create a random label
        target = np.random.randint(low=0, high=num_classes, 
                                   dtype=np.uint8).item()
        # Apply transformations
        img = self.transform(img)
        return img, target

    def __len__(self):
        return self.size

class RandomMask(torch.nn.Module):
    def __init__(self, ratio=0.25):
        super().__init__()
        self.ratio=ratio

    def dilate_mask(self, mask):
        # perform 4 neighbor dilation on mask
        from scipy.signal import convolve2d
        dilated = convolve2d(mask, [[0, 1, 0],
                                    [1, 1, 1],
                                    [0, 1, 0]], mode='same').astype(bool)
        return dilated

    def forward(self, img):
        mask = np.random.uniform(size=(img_size, img_size)) < self.ratio
        dilated_mask = torch.unsqueeze(torch.tensor(self.dilate_mask(mask)),0)
        dilated_mask = dilated_mask.expand(3,-1,-1)
        img[dilated_mask] = 0.
        return img

class ConvertColor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.A=torch.tensor(
            [[0.299, 0.587, 0.114],
             [-0.16874, -0.33126, 0.5],
             [0.5, -0.41869, -0.08131]]
        )
        self.b=torch.tensor([0.,128.,128.])

    def forward(self, img):
        img = img.to(dtype=torch.get_default_dtype())
        img = torch.matmul(self.A,img.view([3,-1])).view(img.shape)
        img = img + self.b[:,None,None]
        return img

class Scale(object):
    def __call__(self, img):
        return img.to(dtype=torch.get_default_dtype()).div(255)

transform = T.Compose(
    [T.PILToTensor(),
     T.RandomCrop(img_size),
     RandomMask(),
     ConvertColor(),
     Scale()])

train_set = FakeDataset(transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=256,
                                           num_workers=4, pin_memory=True)

Next, we instantiate the model, loss function, and optimizer, and define the training step and training loop, which we wrap with a PyTorch Profiler context manager to capture performance data.

import torch
import torch.nn as nn
from statistics import mean, variance
from time import time

device = torch.device("cuda:0")
model = Net().cuda(device)
criterion = nn.CrossEntropyLoss().cuda(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

def train_step(model, criterion, optimizer, inputs, labels):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


model.train()

t0 = time()
times = []

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=10, warmup=2, active=10, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('/tmp/prof'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    for step, data in enumerate(train_loader):
        # copy data to device
        inputs = data[0].to(device=device, non_blocking=True)
        labels = data[1].to(device=device, non_blocking=True)

        # run train step
        train_step(model, criterion, optimizer, inputs, labels)
        prof.step()
        times.append(time()-t0)
        t0 = time()
        if step >= 100:
            break

print(f'average time: {mean(times[1:])}, variance: {variance(times[1:])}')

For our experiments, we use an Amazon EC2 g5.xlarge instance (containing an NVIDIA A10G GPU and 4 vCPUs) running a PyTorch (2.6) Deep Learning AMI (DLAMI). Running our toy script in this environment results in an average throughput of 0.89 steps per second, an underwhelming GPU utilization of 22%, and the following profiling trace:

Profiling Trace of GPU Starvation (by Author)

As discussed in detail in a previous post, the profiling trace shows a clear pattern of GPU starvation — where the GPU spends most of its time waiting for data from the PyTorch DataLoader. This suggests that there is a performance bottleneck in the data input pipeline, which prevents input batches from being prepared quickly enough to keep the GPU fully occupied. Importantly, input pipeline performance issues can stem from a variety of sources. In the case of our toy example, the cause of the bottleneck is not apparent from the trace captured above.

A brief note for readers and developers who (despite all of our lecturing) remain averse to using PyTorch Profiler: The data caching-based technique we discuss below presents an alternative way of identifying GPU starvation — so do not despair.

GPU Starvation — Finding the Root Cause

In this section, we briefly review common causes of performance bottlenecks in the data input pipeline.

Recall that in a typical model execution flow:

  1. Raw data is loaded or streamed from storage (e.g., local RAM or disk, a remote network file system, or a cloud-based object store such as Amazon S3 or Google Cloud Storage).
  2. It is then preprocessed on the CPU.
  3. Finally, the processed data is copied to the GPU for inference or training.

Correspondingly, bottlenecks can emerge at each of the following stages:

  1. Slow data retrieval: There are multiple factors that can limit how quickly raw data can be retrieved by the CPU, including: the choice of storage backend (e.g., cloud storage vs. local SSD), the available network bandwidth, the data format, and more.
  2. CPU resource exhaustion or misuse: Preprocessing tasks — such as data augmentation, image transformations, or decompression — can be CPU-intensive. When the number or complexity of these operations exceeds the available CPU capacity, or if the CPU resources are managed inefficiently (e.g., a suboptimal choice for the number of data-loading workers), a bottleneck can occur. It’s worth noting that CPUs are also responsible for other model-related duties like loading GPU kernels, memory management, metric reporting, and more.
  3. Host-to-device transfer bottlenecks: Once data is processed, it must be transferred to the GPU. This can become a bottleneck if data batches are large relative to the CPU-GPU memory bandwidth, or if the memory copying is performed inefficiently (e.g., individual samples are copied rather than full batches).

The Limitation of Performance Profilers

A common way to identify data pipeline bottlenecks is by using a performance profiler. In part 4 of this series, Solving Bottlenecks on the Data Input Pipeline with PyTorch Profiler and TensorBoard, we demonstrated how to do this using PyTorch’s built-in profiler. However, given that the input data pipeline runs on the CPU, any Python profiler could be used.

The problem with this approach is that we typically use multiple worker processes for data loading, making performance profiling particularly complex. In our previous post, we overcame this by running the data-loading and the model execution in a single process (i.e., we set the num_workers argument of the DataLoader constructor to zero). However, this is a highly intrusive configuration change that can have a significant impact on the overall performance of our model.

The caching-based method we present in this post aims to pinpoint the source of the performance bottleneck in a far less intrusive manner. In particular, it will enable us to measure the model performance without altering the multi-worker data-loading behavior.

Bottleneck Detection via Caching

In this section, we propose a multi-step approach for analyzing the performance of the input data pipeline. We’ll demonstrate how this method can be applied to our toy training workload to identify the causes of the GPU starvation.

Step 1: Cache a Batch on the Device

We begin by creating a single input batch, copying it to the GPU, and then measuring the runtime performance of the model when iterating over just that batch. This provides a theoretical upper bound on the model’s throughput — i.e., the maximum throughput achievable when the GPU is not data-starved.

In the following code block we modify the training loop of our toy script so that it runs on a single batch that is cached on the GPU:

data = next(iter(train_loader))
inputs = data[0].to(device=device, non_blocking=True)
labels = data[1].to(device=device, non_blocking=True)
t0 = time()
times = []
for step in range(100):
    train_step(model, criterion, optimizer, inputs, labels)
    times.append(time()-t0)
    t0 = time()

The resultant average throughput is 3.45 steps per second — nearly 4 times higher than our baseline result. Not only does this confirm a significant data pipeline bottleneck, but it also quantifies its impact.

Bonus Tip: Profile and Optimize with Device-Cached Data
Running a profiler on a single batch cached on the GPU isolates the model execution from the input pipeline. This helps you identify inefficiencies in the model’s raw compute path. Ideally, GPU utilization here should approach 100%. In our case, utilization is around 95%, which is acceptable.
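As a rough sketch, this could reuse the profiler settings from our baseline script, applied to the device-cached loop from Step 1 (the trace output path is arbitrary):

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=10, warmup=2, active=10, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('/tmp/prof_cached'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    # inputs and labels are the single batch already cached on the GPU
    for step in range(100):
        train_step(model, criterion, optimizer, inputs, labels)
        prof.step()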

Step 2: Cache a Batch on the Host (CPU)

Next, we cache a single input batch on the host (CPU) instead of the device. Now, each step includes both a memory copy from CPU to GPU and the model execution.

Since PyTorch’s memory pinning allows for asynchronous data transfers, we expect the host-to-device memory copy for batch N+1 to overlap with the model execution on batch N. Consequently, our expectation is that the throughput will be in the same ballpark as in the device-cached case. If not, this would be a clear indication of a bottleneck in the host to device memory copy.

The following block of code contains our application of this step to our toy model:

data = next(iter(train_loader))
t0 = time()
times = []
for step in range(100):
    inputs = data[0].to(device=device, non_blocking=True)
    labels = data[1].to(device=device, non_blocking=True)
    train_step(model, criterion, optimizer, inputs, labels)
    times.append(time()-t0)
    t0 = time()

The resultant throughput following this change is 3.33 steps per second — a minor drop from the previous result — indicating that the host-to-device transfer is not a bottleneck. We need to keep looking for the source of our performance bottleneck.

Steps 3 and on: Cache at Intermediate Stages in the Data Pipeline

We continue our search by “climbing” up the data input pipeline, caching at various intermediate points to pinpoint the bottleneck. The precise application of this process will vary based on the details of the pipeline. Suppose the pipeline can be broken into stages. If caching after stage N yields significantly worse throughput than caching after stage N+1, we can deduce that the processing performed in stage N+1 is what is slowing us down.

Step 3a: Cache a Single Processed Sample
In the code block below, we modify our dataset to cache one fully processed sample. This simulates a pipeline that includes data collation and the CPU to GPU data copy.

class FakeDataset(VisionDataset):
    def __init__(self, transform):
        super().__init__(root=None, transform=transform)
        self.size = 10000
        self.cache = None

    def __getitem__(self, index):
        if self.cache is None:
            # create a random 1024x1024 image
            img = Image.fromarray(np.random.randint(
                low=0,
                high=256,
                size=(input_img_size[0], input_img_size[1], 3),
                dtype=np.uint8
            ))
            # create a random label
            target = np.random.randint(low=0, high=num_classes,
                                       dtype=np.uint8).item()
            # Apply transformations
            img = self.transform(img)
            self.cache = img, target
        return self.cache

The resultant throughput is 3.23 steps per second — still far higher than our baseline of 0.89. We still have not found the culprit.

Step 3b: Cache Raw Data (Before Transformation)
Next, we modify the dataset so as to cache the raw data (e.g., unprocessed image files). The input data pipeline now includes the data transformations, data collation, and the CPU to GPU data copy.

class FakeDataset(VisionDataset):
    def __init__(self, transform):
        super().__init__(root=None, transform=transform)
        self.size = 10000
        self.cache = None

    def __getitem__(self, index):
        if self.cache is None:
            # create a random 1024x1024 image
            img = Image.fromarray(np.random.randint(
                low=0,
                high=256,
                size=(input_img_size[0], input_img_size[1], 3),
                dtype=np.uint8
            ))
            # create a random label
            target = np.random.randint(low=0, high=num_classes,
                                       dtype=np.uint8).item()
            self.cache = img, target
        # Apply transformations
        img = self.transform(self.cache[0])
        return img, self.cache[1]

This time, the throughput drops sharply — all the way down to 1.72 steps per second. We have found our first culprit: the data transformation function.

Interim Results

Here’s a summary of the experiments so far:

Caching Experiment Results (by Author)

The results point to a significant slowdown introduced by the data transformation step. The gap between the raw data caching result and the baseline also suggests that raw data loading may be another culprit. Let’s begin with the data processing bottleneck.

Optimizing the Data Transformation

Having discovered a performance bottleneck in the data processing function, the next logical step would be to break the transform function into individual components and apply our caching technique to each one in order to gain more insight into the precise sources of our GPU starvation (one possible way of doing this is sketched below). For the sake of brevity, we will skip ahead and apply the data processing optimizations discussed in our previous post, Solving Bottlenecks on the Data Input Pipeline with PyTorch Profiler and TensorBoard. Please see there for details.
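A minimal sketch of such a per-transform caching wrapper (the class name and usage below are illustrative and are not part of our toy script) might look as follows:

class CachedTransform(torch.nn.Module):
    # caches the output of the wrapped transform after the first call
    def __init__(self, transform):
        super().__init__()
        self.transform = transform
        self.cache = None

    def forward(self, x):
        if self.cache is None:
            self.cache = self.transform(x)
        # note: if downstream transforms modify their input in place,
        # return a copy (e.g., self.cache.clone()) instead
        return self.cache

# example: cache everything up to and including RandomMask so that only
# ConvertColor and Scale run on every step
transform = T.Compose(
    [CachedTransform(T.Compose([T.PILToTensor(),
                                T.RandomCrop(img_size),
                                RandomMask()])),
     ConvertColor(),
     Scale()])

Sliding the wrapper forward one stage at a time and comparing the resulting throughputs isolates the cost of each individual transform.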

Following the data transformation optimizations, the throughput of the cached raw data experiment shoots up to 3.23 steps per second. We have eliminated the bottleneck in the data processing function.

However, our new baseline throughput (without caching) becomes 1.28 steps per second, indicating that there remains a bottleneck in the raw data loading. This is similar to the end result we reached in our previous post.

Throughput Following Transform Optimization (by Author)

Optimizing Raw Data Loading

To resolve the remaining bottleneck, we simulate the optimization demonstrated in part 5 of this series, How to Optimize Your DL Data-Input Pipeline with a Custom PyTorch Operator. We do this by reducing the size of our initial random image from 1024×1024 to 256×256. Following this change, the end-to-end (un-cached) throughput increases to 3.23 steps per second. Problem solved!
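In our toy script, this simulation simply amounts to shrinking the synthetic source images, e.g. (assuming the input_img_size global defined alongside our model sketch):

# simulate faster raw-data retrieval by generating smaller source images
input_img_size = (256, 256)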

Important Caveats

We conclude with a few important notes and caveats.

  1. A drop in throughput resulting from the inclusion of a certain data-processing step in the data pipeline does not necessarily mean that that specific step requires optimization. It is entirely possible that other steps have pushed CPU utilization close to its limit, and the new step simply tipped it over.
  2. If your input data varies in size, throughput from a single cached data sample or batch of samples may not reflect real-world performance.
  3. The same caveat applies if the AI model includes dynamic, data-dependent features, e.g., if components of the model graph depend on the input data.

Tips, Tricks, and Techniques for Addressing Bottlenecks on the Data Input Pipeline

We conclude this post with a list of tips, tricks, and techniques for optimizing the data input pipeline of PyTorch-based AI models. This list is by no means exhaustive — numerous additional optimizations exist depending on your specific use case and infrastructure. We divide the optimizations into three categories:

  • Optimizing Raw Data Entry/Retrieval
  • Optimizing Data Processing
  • Optimizing Host-to-Device Data Transfer

Optimizing Raw Data Entry/Retrieval

Efficient data loading starts with fast and reliable access to raw data. The following tips can help:

  • Choose an instance type with sufficient network ingress bandwidth.
  • Use a fast and cost-effective data storage solution. Local SSDs are fast but expensive. Cloud-based solutions like S3 offer scalability, but may introduce latency.
  • Maximize storage network egress. Consider partitioning datasets in S3 or tuning parallel downloads to reduce throttling.
  • Consider raw data compression. Compressing files can reduce transfer time — but watch out for increased CPU cost during decompression.
  • Group small samples into larger files. This can reduce overhead associated with opening and closing many files.
  • Use optimized data transfer tools. For example, s5cmd can significantly outperform the AWS CLI for bulk S3 downloads.
  • Tune data retrieval parameters. Adjusting chunk size or concurrency settings can greatly impact read performance.
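As one sketch of the last tip, if data were pulled from S3 with boto3, the transfer configuration might be tuned as follows (the bucket, key, and values are purely illustrative):

import boto3
from boto3.s3.transfer import TransferConfig

s3 = boto3.client('s3')
config = TransferConfig(
    multipart_chunksize=16 * 1024 * 1024,  # size of each download chunk
    max_concurrency=16                     # number of parallel download threads
)
s3.download_file('my-bucket', 'path/to/shard.tar', '/tmp/shard.tar',
                 Config=config)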

Addressing Data Processing Bottlenecks

  • Tune the number of data-loading workers and the prefetch factor (see the DataLoader sketch following this list).
  • Whenever possible, offload data-processing to the data preparation phase.
  • Choose an instance type with an optimal CPU/GPU compute ratio.
  • Optimize the order of transformations. For example, applying a crop before blurring will be faster than blurring the full-sized image and only then cropping.
  • Leverage Python acceleration libraries. For example, Numba and JAX can speed up pure Python operations via JIT compilation.
  • Create custom PyTorch CPU operators where appropriate (e.g., see here).
  • Consider adding auxiliary CPUs (data servers) — (e.g., see here).
  • Move GPU-friendly transforms to the GPU graph. Some transforms (e.g., normalization) can be done post-loading on the GPU for better overlap.
  • Tune OS-level thread and memory configurations.
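As a sketch of the first tip in this list (the values below are illustrative starting points rather than recommendations):

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=256,
    num_workers=8,            # often tuned relative to the number of available vCPUs
    prefetch_factor=4,        # batches prefetched per worker
    pin_memory=True,
    persistent_workers=True   # avoid re-forking workers every epoch
)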

Optimizing the Host to Device Data Copy

  • Use memory pinning and non-blocking data copies to prefetch data directly onto the GPU. Also see the dedicated CudaDataPrefetcher offered by TorchTNT.
  • Postpone int8 to float32 datatype conversions to the GPU to reduce the memory copy payload by a factor of 4 (see the sketch following this list).
  • If your model uses lower-precision floats (e.g., fp16/bfloat16), cast the floats on the CPU to reduce the payload by half.
  • Postpone unpacking of one-hot vectors to the GPU — i.e., keep them as label ids until the last possible moment.
  • If you have many binary values, consider using bitmasks to compress the payload. For example, if you have 8 binary maps, consider compressing them into a single uint8.
  • If your input data is sparse, consider using sparse data representations.
  • Avoid unnecessary padding. While zero-padding is a popular technique for dealing with variable sized input samples, it can significantly increase the size of the memory copy. Consider alternative options (e.g., see here).
  • Make sure you are not copying data that you do not actually need on the GPU!!
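As a sketch of the earlier tip on postponing int8-to-float32 conversions (this assumes the pipeline has been restructured so that samples remain uint8 on the CPU):

for data in train_loader:
    # copy the compact uint8 batch to the GPU
    inputs = data[0].to(device=device, non_blocking=True)
    labels = data[1].to(device=device, non_blocking=True)
    # perform the float conversion (and scaling) on the GPU
    inputs = inputs.to(dtype=torch.float32).div_(255)
    train_step(model, criterion, optimizer, inputs, labels)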

Summary

While GPUs are considered essential for modern-day AI/ML development, they come at a steep price. Once you’ve decided to make the necessary investment in their acquisition, you will want to make sure they are being used as much as possible. The last thing you want is for your GPU to sit idle, waiting for input data due to a preventable bottleneck elsewhere in the pipeline.

Unfortunately, such inefficiencies are all too common. In this post, we introduced a simple technique for diagnosing these issues by iteratively caching data at different stages of the input pipeline. By isolating the runtime impact of each pipeline component, this method helps identify specific bottlenecks — whether in raw data loading, preprocessing, or host-to-device transfer.

Of course, the exact implementation will vary across projects and pipelines, but we hope this strategy provides a useful framework for diagnosing and resolving performance issues in your own AI/ML workflows.
