PyTorch Autograd

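Write a custom autograd function (`torch.autograd.Function`) to speed things up when `torch.compile` cannot properly fuse two kernels together: you write the forward and backward passes yourself and choose exactly which tensors are saved for the backward pass.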

https://docs.pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html

class MultiplyAdd(torch.autograd.Function):
    """Compute ``x * w + b`` with a hand-written backward pass.

    For an upstream gradient ``g``:
        d/dx = g * w,   d/dw = g * x,   d/db = g
    """

    @staticmethod
    def forward(ctx, x, w, b):
        # Only x and w are needed for the gradients; b's gradient is just the
        # upstream gradient, so b does not need to be saved.
        ctx.save_for_backward(x, w)
        return x * w + b

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        need_x, need_w, need_b = ctx.needs_input_grad

        # Return None for any input that does not need a gradient. This skips
        # unnecessary multiplies, and it is required when an input is not a
        # tensor (e.g. ``b`` passed as a Python float): autograd rejects a
        # non-None gradient for a non-tensor forward input.
        grad_x = grad_out * w if need_x else None
        grad_w = grad_out * x if need_w else None
        grad_b = grad_out if need_b else None

        return grad_x, grad_w, grad_b

Example: a fused square-ReLU, f(x) = relu(x)^2

class SquareReLU(torch.autograd.Function):
    """Compute ``relu(x) ** 2`` with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x):
        y = torch.relu(x)
        # relu(x) alone is enough to form the gradient; x itself is not saved.
        ctx.save_for_backward(y)
        return y * y

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        # d(relu(x)^2)/dx = 2 * relu(x) * 1[x > 0].
        # relu(x) is already 0 wherever x <= 0, so the indicator is implied:
        # 2 * relu(x) * 1[x > 0] == 2 * relu(x). Dropping the explicit mask
        # avoids building a boolean tensor and one extra elementwise multiply.
        grad_x = grad_out * 2 * y
        return grad_x
 
def square_relu(x):
    """Return ``relu(x) ** 2`` computed through the custom ``SquareReLU`` Function."""
    apply_fn = SquareReLU.apply
    return apply_fn(x)

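Tip: check that your analytical gradients are correct with `torch.autograd.gradcheck`, which compares them against numerical (finite-difference) gradients. Use double-precision inputs; the default tolerances are designed for float64.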

from torch.autograd import gradcheck
 
def fn(x):
    """Wrap ``SquareReLU.apply`` as a plain function for ``gradcheck``."""
    result = SquareReLU.apply(x)
    return result
 
# gradcheck's default tolerances are designed for double precision;
# lower-precision inputs will likely fail the check.
x = torch.randn((5,), dtype=torch.float64, requires_grad=True)
print(gradcheck(fn, inputs=(x,), eps=1e-6, atol=1e-4))