itertools

The islice function can be quite helpful, e.g. for running a loop over only the first few batches of a dataloader:

from itertools import islice
 
# Train on only the first 10 batches: islice stops pulling from the
# dataloader after 10 items. islice objects have no len(), so tqdm is given
# total=10 explicitly to render a proper progress bar.
for batch, (X, y) in enumerate(tqdm(islice(train_dataloader, 10), total=10)):
    X, y = X.to(DEVICE), y.to(DEVICE)

    # Standard optimisation step: clear old gradients, forward, backward, update.
    optimizer.zero_grad(set_to_none=True)
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()

    # The predicted class is the index of the HIGHEST score along dim 1.
    # (argmin would select the least likely class and make the accuracy
    # count meaningless.)
    # NOTE(review): assumes pred is (batch, num_classes) scores and y holds
    # class indices -- confirm against the model / loss_fn.
    pred_indices = torch.argmax(pred, dim=1)
    num_correct_train += (pred_indices == y).sum()