Newton-Schulz Iteration

The Newton-Schulz iteration is a quadratically convergent, inversion-free method for computing the sign function of a matrix. It is advantageous over other methods for high-performance computing because it is rich in matrix-matrix multiplications.

Used in the Muon Optimizer.

What Muon does differently from Adam

Adam scales each gradient element independently (per-element moments). Muon instead takes the whole gradient matrix and finds the nearest orthogonal matrix to it. That’s the update it applies.

Why? An orthogonal update means every singular direction of the gradient gets an equal-magnitude step — no direction receives a disproportionately large or small change. It’s a way of normalizing the gradient that respects its matrix structure rather than treating it as a flat vector.

What Newton-Schulz does

Computing the nearest orthogonal matrix is called finding the “polar factor” — mathematically it’s G @ (G^T @ G)^{-1/2}. You could do this with SVD, but SVD is slow on GPU.
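As a concrete reference point, here is the slow-but-exact SVD route (the toy matrix and seed are illustrative): if G = U @ diag(S) @ Vt, then G @ (G^T @ G)^{-1/2} collapses to U @ Vt.

```python
import numpy as np

# Exact polar factor via SVD (the baseline Newton-Schulz approximates).
rng = np.random.default_rng(0)
G = rng.standard_normal((4, 4))

U, S, Vt = np.linalg.svd(G)
polar = U @ Vt                  # G @ (G^T @ G)^{-1/2} reduces to U @ Vt

# The result is orthogonal: polar.T @ polar is the identity.
print(np.allclose(polar.T @ polar, np.eye(4)))  # True
```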

Newton-Schulz is an iterative approximation that converges to the same result using only matrix multiplies (which GPUs are great at):

  1. Normalize the gradient so it has unit norm (line 103)

  2. Iterate (lines 107-110):

    A = X @ X.T          # how far X is from orthogonal (would be I if orthogonal)
    B = b * A + c * A @ A  # correction term (polynomial in A)
    X = a * X + B @ X      # refine X toward orthogonality

    Each iteration pushes X closer to orthogonal. The magic constants a, b, c are tuned coefficients of an odd 5th-order (quintic) polynomial, chosen so the first few iterations drive the singular values toward 1 as fast as possible.

  3. After ~5 steps, X is approximately orthogonal — X @ X.T ≈ I.
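The three steps above can be sketched end-to-end in NumPy. The toy matrix and seed are illustrative; the coefficients are the quintic ones published with Muon’s reference implementation.

```python
import numpy as np

a, b, c = 3.4445, -4.7750, 2.0315   # Muon's quintic coefficients

# Toy "gradient" with a known spread of singular values: 3.0, 1.0, 0.5, 0.1.
rng = np.random.default_rng(0)
Q1, _ = np.linalg.qr(rng.standard_normal((4, 4)))
Q2, _ = np.linalg.qr(rng.standard_normal((4, 4)))
G = Q1 @ np.diag([3.0, 1.0, 0.5, 0.1]) @ Q2.T

X = G / np.linalg.norm(G)   # 1. normalize to unit Frobenius norm
for _ in range(5):          # 2. five quintic iterations
    A = X @ X.T             #    would be I if X were orthogonal
    B = b * A + c * A @ A   #    polynomial correction term
    X = a * X + B @ X       #    push X toward orthogonality

# 3. The singular values of X are now all close to 1.
print(np.linalg.svd(X, compute_uv=False))
```

Note that these particular coefficients don’t drive the singular values exactly to 1: they trade a tight limit for rapid early progress, so after five steps every singular value lands in a band around 1 — close enough for an optimizer update.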

The “zeropower” name

It’s called “zeropower” because computing (G^T @ G)^{-1/2} is the matrix raised to the -1/2 power, and the combined result G @ (G^T @ G)^{-1/2} effectively raises the singular values to the 0th power (i.e., all singular values become 1), while preserving the singular vectors. That’s exactly what makes the result orthogonal.
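This reading can be checked directly (illustrative example): build G with known singular values, apply the defining formula, and confirm that every singular value of the result is 1.

```python
import numpy as np

# G with known singular values 2, 1, 0.5, 0.25 (random orthogonal factors).
rng = np.random.default_rng(1)
Q1, _ = np.linalg.qr(rng.standard_normal((4, 4)))
Q2, _ = np.linalg.qr(rng.standard_normal((4, 4)))
G = Q1 @ np.diag([2.0, 1.0, 0.5, 0.25]) @ Q2.T

# (G^T @ G)^{-1/2} via the eigendecomposition of the symmetric matrix G^T G.
w, Q = np.linalg.eigh(G.T @ G)
zeropower = G @ (Q @ np.diag(w ** -0.5) @ Q.T)

# All singular values are raised to the 0th power, i.e. become 1.
print(np.linalg.svd(zeropower, compute_uv=False))
```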

The transpose trick

Lines 104-106: if the matrix is tall (more rows than columns), it is transposed first so the iteration’s X @ X.T products are formed on the smaller dimension, then transposed back at the end. Just an efficiency trick.
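The payoff is easy to see with hypothetical shapes: for a 1024x64 gradient, iterating on the 64x1024 transpose makes each X @ X.T product a 64x64 matrix instead of 1024x1024.

```python
import numpy as np

def maybe_transpose(X):
    """Work on the smaller dimension: transpose tall matrices, remember it."""
    if X.shape[0] > X.shape[1]:
        return X.T, True
    return X, False

G = np.zeros((1024, 64))            # tall toy gradient
X, flipped = maybe_transpose(G)
print(X.shape)                      # (64, 1024): X @ X.T is now only 64x64

# ... run the Newton-Schulz iterations on X here ...

if flipped:
    X = X.T                         # transpose back to the original orientation
print(X.shape)                      # (1024, 64)
```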