# CUDA Prefetcher — overlap host-to-device copies with GPU compute.

import torch
 
class CUDAPrefetcher:
    """Wrap a DataLoader and prefetch batches to the GPU on a side stream.

    While the model computes on batch *i* (default stream), the host-to-device
    copy of batch *i+1* runs concurrently on a dedicated CUDA stream. For the
    copies to actually be asynchronous, the loader should produce pinned
    (page-locked) host tensors, e.g. ``DataLoader(..., pin_memory=True)``.

    Iterating yields ``(x, y)`` pairs already resident on ``device``.
    """

    def __init__(self, loader, device="cuda"):
        self.loader = loader
        self.device = device
        # Dedicated side stream: copies enqueued here overlap with compute
        # submitted on the default stream.
        self.stream = torch.cuda.Stream()

    def _to_device_async(self, batch):
        """Enqueue async H2D copies of one ``(x, y)`` batch on the side stream."""
        x, y = batch
        # This enqueues the copy onto *self.stream* (not the default stream).
        with torch.cuda.stream(self.stream):
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)
        return x, y

    def __iter__(self):
        it = iter(self.loader)

        # Prime the pipeline: start copying batch 0.
        # Catch StopIteration explicitly — letting it escape a generator
        # raises RuntimeError (PEP 479). An empty loader yields nothing.
        try:
            next_x, next_y = self._to_device_async(next(it))
        except StopIteration:
            return

        for batch in it:
            # Block the default stream until the in-flight copy of the batch
            # we are about to hand out has finished.
            torch.cuda.current_stream().wait_stream(self.stream)
            x, y = next_x, next_y
            # Tell the caching allocator these tensors are consumed on the
            # current stream, so their memory is not reused by the side
            # stream while the consumer is still reading it.
            x.record_stream(torch.cuda.current_stream())
            y.record_stream(torch.cuda.current_stream())
            # Kick off the copy of the *following* batch before yielding, so
            # it overlaps with the caller's compute on (x, y).
            next_x, next_y = self._to_device_async(batch)
            yield x, y

        # Drain the last prefetched batch.
        torch.cuda.current_stream().wait_stream(self.stream)
        next_x.record_stream(torch.cuda.current_stream())
        next_y.record_stream(torch.cuda.current_stream())
        yield next_x, next_y
 

# Usage

# ---- usage example ----
# loader = DataLoader(dataset, batch_size=..., num_workers=..., pin_memory=True, persistent_workers=True)
# model = model.to(device)
 
def train_one_epoch(model, loader, optimizer, loss_fn, device="cuda"):
    """Run one training epoch with GPU prefetching.

    Args:
        model: module already moved to ``device``.
        loader: DataLoader yielding ``(x, y)`` CPU batches (ideally pinned).
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(outputs, targets) -> scalar loss``.
        device: target device for the prefetched batches (default ``"cuda"``).
    """
    model.train()
    # Fix: the class defined above is CUDAPrefetcher (not PrefetchLoader),
    # and `device` must be a parameter — it was previously an undefined name.
    prefetch = CUDAPrefetcher(loader, device=device)

    for x, y in prefetch:
        optimizer.zero_grad(set_to_none=True)
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()