# CUDA Prefetcher
import torch
class CUDAPrefetcher:
    """Iterate a DataLoader while overlapping host-to-device copies with compute.

    Each batch is copied to ``device`` on a dedicated side CUDA stream while the
    previous batch is being consumed, hiding transfer latency. For this to be
    truly asynchronous the loader should use ``pin_memory=True`` — TODO confirm
    at the call site. On CPU (or when CUDA is unavailable) it degrades to plain
    synchronous ``.to(device)`` copies instead of crashing in ``torch.cuda``.

    NOTE(review): tensors produced on the side stream are later consumed on the
    default stream; strictly, each yielded tensor should also call
    ``record_stream(torch.cuda.current_stream())`` to keep the caching allocator
    from reusing its memory early — the ``wait_stream`` calls below cover the
    common single-consumer training loop, but verify for exotic usage.
    """

    def __init__(self, loader, device="cuda"):
        """
        Args:
            loader: any iterable yielding ``(x, y)`` tensor pairs
                (typically a ``torch.utils.data.DataLoader``).
            device: target device for the batches; defaults to ``"cuda"``.
        """
        self.loader = loader
        self.device = device
        # A side stream only exists / helps on CUDA. Creating torch.cuda.Stream()
        # on a CPU-only build raises, so guard it and fall back to sync copies.
        self._use_stream = (
            torch.cuda.is_available() and str(device).startswith("cuda")
        )
        self.stream = torch.cuda.Stream() if self._use_stream else None

    def _to_device_async(self, batch):
        """Copy one (x, y) batch to self.device.

        On CUDA the copies are enqueued on *self.stream* (not the default
        stream) with ``non_blocking=True``; on CPU this is a plain sync copy.
        """
        x, y = batch
        if self.stream is None:
            # No CUDA: nothing to overlap, just move (usually a no-op on CPU).
            return x.to(self.device), y.to(self.device)
        with torch.cuda.stream(self.stream):
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)
        return x, y

    def __iter__(self):
        it = iter(self.loader)
        # Prime the pipeline: start copying batch 0. Guard the empty-loader
        # case — a bare next() here would surface as RuntimeError (PEP 479).
        try:
            first = next(it)
        except StopIteration:
            return
        next_x, next_y = self._to_device_async(first)
        for batch in it:
            if self.stream is not None:
                # Make the default stream wait until the prefetched copy of the
                # batch we are about to yield has landed.
                torch.cuda.current_stream().wait_stream(self.stream)
            # Hand off the already-copied batch, then kick off the next copy
            # so it overlaps with the caller's forward/backward pass.
            # (The original yielded x, y before ever assigning them — NameError
            # on the first iteration; fixed by swapping before the yield.)
            x, y = next_x, next_y
            next_x, next_y = self._to_device_async(batch)
            yield x, y
        # Drain the last prefetched batch.
        if self.stream is not None:
            torch.cuda.current_stream().wait_stream(self.stream)
        yield next_x, next_y
# Usage
# ---- usage example ----
# loader = DataLoader(dataset, batch_size=..., num_workers=..., pin_memory=True, persistent_workers=True)
# model = model.to(device)
def train_one_epoch(model, loader, optimizer, loss_fn, device="cuda"):
    """Run one training epoch over *loader* with prefetched device transfers.

    Args:
        model: the module to train; assumed already moved to *device*.
        loader: iterable of (x, y) batches (e.g. a DataLoader).
        optimizer: optimizer stepping *model*'s parameters.
        loss_fn: callable mapping (output, target) -> scalar loss tensor.
        device: device batches are prefetched to (new, defaulted parameter —
            the original read an undefined global ``device``).

    Returns:
        Mean loss over the epoch as a float, or 0.0 for an empty loader.
    """
    model.train()
    # Original instantiated the nonexistent name ``PrefetchLoader``; the class
    # defined in this file is ``CUDAPrefetcher``.
    prefetch = CUDAPrefetcher(loader, device=device)
    total_loss = 0.0
    num_batches = 0
    for x, y in prefetch:
        optimizer.zero_grad(set_to_none=True)
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
        # .item() syncs, but only once per step; fine for epoch bookkeeping.
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0