PyTorch Autograd
Writing a custom autograd function is one way to get a speedup when torch.compile cannot properly fuse two kernels together.
https://docs.pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html
class MultiplyAdd(torch.autograd.Function):
    """Compute ``x * w + b`` with a hand-written backward pass.

    ``x`` and ``w`` must be tensors (they are stored with
    ``save_for_backward``). The operands may broadcast against each other;
    every returned gradient is reduced back to the shape of its input.
    """

    @staticmethod
    def forward(ctx, x, w, b):
        """Return ``x * w + b`` and stash what backward needs."""
        ctx.save_for_backward(x, w)
        # The value of b never appears in any gradient; only its shape is
        # needed to undo broadcasting. A non-tensor b gets no gradient.
        ctx.b_shape = b.shape if torch.is_tensor(b) else None
        return x * w + b

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients w.r.t. ``(x, w, b)``, in that order.

        An entry is ``None`` when the corresponding input does not require
        a gradient.
        """
        x, w = ctx.saved_tensors
        grad_x = grad_w = grad_b = None

        # grad_out has the broadcast output shape. sum_to_size sums over the
        # broadcast dimensions so each gradient matches its input's shape.
        if ctx.needs_input_grad[0]:
            grad_x = (grad_out * w).sum_to_size(x.shape)
        if ctx.needs_input_grad[1]:
            grad_w = (grad_out * x).sum_to_size(w.shape)
        if ctx.needs_input_grad[2]:
            grad_b = grad_out.sum_to_size(ctx.b_shape)
        return grad_x, grad_w, grad_b


# Example
class SquareReLU(torch.autograd.Function):
    """Compute ``relu(x) ** 2`` with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x):
        """Return ``relu(x) * relu(x)``."""
        # Save the *input*, not the intermediate relu(x). An intermediate
        # created inside forward has no autograd history, so a backward
        # built on it would lose the dependence on x under double backward.
        ctx.save_for_backward(x)
        y = torch.relu(x)
        return y * y

    @staticmethod
    def backward(ctx, grad_out):
        """Return the gradient of the output with respect to ``x``."""
        (x,) = ctx.saved_tensors
        # d(relu(x)^2)/dx = 2 * relu(x) * 1[x > 0]
        # relu(x) is already zero wherever x <= 0, so the indicator is
        # implied and no explicit mask is needed.
        grad_x = grad_out * 2 * torch.relu(x)
        return grad_x
def square_relu(x):
    """Apply the custom ``SquareReLU`` autograd function to ``x``."""
    # Custom autograd Functions are invoked through ``.apply``.
    return SquareReLU.apply(x)


# Tip: you can check that your analytical gradients are correct by using
# torch.autograd.gradcheck
from torch.autograd import gradcheck


def fn(x):
    """Function under test for gradcheck: forwards ``x`` to ``SquareReLU``."""
    return SquareReLU.apply(x)


# gradcheck compares the analytical backward against finite differences,
# so the input is created in double precision with requires_grad=True.
x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(gradcheck(fn, (x,), eps=1e-6, atol=1e-4))