创建 autograd.Function
的子类,需要实现两个静态的方法 forward 和 backward。应用该 op 时,调用 apply 方法,不要直接调用 forward 方法。
forward 静态方法中第一个参数为 ctx,它可以理解 Function 对象本身,其方法 save_for_backward 用于在前向计算时,将反向计算用到的中间结果进行缓存。在 backward 静态方法中,也有 ctx 方法,我们可以通过它的 saved_tensors 来取出前向计算时缓存的中间结果。
前向计算方法 forward 的其他参数为该 op 需要传递进来的参数,例如:该 op 用作乘法运算,则其他的两个参数就是需要相乘的两个张量。后向计算方法 backward 的另一个参数为 grad_outputs,我们知道神经网络是通过链式求导的方法来计算参数梯度,这里的 grad_outputs 为上一个步输出的梯度。backward 方法最终要返回相应张量的梯度,例如:forward 函数中按顺序传入了 w,x 则在 backward 中计算完 w 和 x 梯度之后,以此返回即可。
示例代码:
import torch from torchviz import make_dot class CustomMul(torch.autograd.Function): @staticmethod def forward(ctx, w, x): # print('CustomMul forward') ctx.save_for_backward(w, x) return w * x @staticmethod def backward(ctx, grad_outputs): # print('CustomMul backward') # print('CustomMul grad_outputs:', grad_outputs) w = ctx.saved_tensors[0] x = ctx.saved_tensors[1] w_grad = grad_outputs * x x_grad = grad_outputs * w return w_grad, x_grad class CustomAdd(torch.autograd.Function): @staticmethod def forward(ctx, x, b): # print('CustomAdd forward') return x + b @staticmethod def backward(ctx, grad_outputs): # print('CustomAdd backward') # print('CustomAdd grad_outputs:', grad_outputs) b_grad = grad_outputs * 1 x_grad = grad_outputs * 1 return x_grad, b_grad def mul(w, input): return CustomMul.apply(w, input) def add(b, input): return CustomAdd.apply(b, input) if __name__ == '__main__': w = torch.tensor([[2.0, 3.0]], requires_grad=True) b = torch.tensor([3.0], requires_grad=True) x = torch.tensor([[4.0, 5.0]]) outputs = mul(w, x) outputs = add(b, torch.sum(outputs)) dot = make_dot(outputs, show_attrs=True) dot.render('temp', format='png', cleanup=False) outputs.backward() print(w.grad, x.grad, b.grad)
程序执行结果:
tensor([[4., 5.]], dtype=torch.float64) None tensor([1.])
计算图如下:
最后,可以通过 torch.autograd.gradcheck
检查下自定义的 CustomMul 工作是否正常:
print(torch.autograd.gradcheck(func=CustomMul.apply, inputs=(w, x)))
输出结果为:True