itertools
islice function can be quite helpful:
from itertools import islice
# Train on only the first 10 batches: islice stops pulling from the dataloader
# after 10 items, and total=10 gives tqdm the matching progress-bar length
# (an islice object has no len() for tqdm to discover on its own).
for batch, (X, y) in enumerate(tqdm(islice(train_dataloader, 10), total=10)):
    # Move the inputs and targets to the device used for training.
    X, y = X.to(DEVICE), y.to(DEVICE)

    # Clear gradients left over from the previous step before the new backward pass.
    optimizer.zero_grad(set_to_none=True)

    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()

    # The predicted class is the index of the HIGHEST score along dim=1.
    # (The previous argmin picked the lowest-scoring class, so the running
    # "correct" count was measuring the wrong thing.)
    # NOTE(review): assumes pred is (batch, num_classes) and y holds class
    # indices -- confirm against the model and dataset.
    pred_indices = torch.argmax(pred, dim=1)
    num_correct_train += (pred_indices == y).sum()