PyTorch 自定义算子

创建 autograd.Function 的子类,需要实现两个静态的方法 forward 和 backward。应用该 op 时,调用 apply 方法,不要直接调用 forward 方法。

forward 静态方法中第一个参数为 ctx,它可以理解 Function 对象本身,其方法 save_for_backward 用于在前向计算时,将反向计算用到的中间结果进行缓存。在 backward 静态方法中,也有 ctx 方法,我们可以通过它的 saved_tensors 来取出前向计算时缓存的中间结果。

前向计算方法 forward 的其他参数为该 op 需要传递进来的参数,例如:该 op 用作乘法运算,则其他的两个参数就是需要相乘的两个张量。后向计算方法 backward 的另一个参数为 grad_outputs,我们知道神经网络是通过链式求导的方法来计算参数梯度,这里的 grad_outputs 为上一个步输出的梯度。backward 方法最终要返回相应张量的梯度,例如:forward 函数中按顺序传入了 w,x 则在 backward 中计算完 w 和 x 梯度之后,以此返回即可。

示例代码:

import torch
from torchviz import make_dot


class CustomMul(torch.autograd.Function):

    @staticmethod
    def forward(ctx, w, x):
        # print('CustomMul forward')
        ctx.save_for_backward(w, x)
        return w * x

    @staticmethod
    def backward(ctx, grad_outputs):
        # print('CustomMul backward')
        # print('CustomMul grad_outputs:', grad_outputs)
        w = ctx.saved_tensors[0]
        x = ctx.saved_tensors[1]

        w_grad = grad_outputs * x
        x_grad = grad_outputs * w

        return w_grad, x_grad


class CustomAdd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, b):
        # print('CustomAdd forward')
        return x + b

    @staticmethod
    def backward(ctx, grad_outputs):
        # print('CustomAdd backward')
        # print('CustomAdd grad_outputs:', grad_outputs)

        b_grad = grad_outputs * 1
        x_grad = grad_outputs * 1

        return x_grad, b_grad


def mul(w, input):
    return CustomMul.apply(w, input)

def add(b, input):
    return CustomAdd.apply(b, input)


if __name__ == '__main__':

    w = torch.tensor([[2.0, 3.0]], requires_grad=True)
    b = torch.tensor([3.0], requires_grad=True)
    x = torch.tensor([[4.0, 5.0]])

    outputs = mul(w, x)
    outputs = add(b, torch.sum(outputs))

    dot = make_dot(outputs, show_attrs=True)
    dot.render('temp', format='png', cleanup=False)

    outputs.backward()
    print(w.grad, x.grad, b.grad)

程序执行结果:

tensor([[4., 5.]], dtype=torch.float64) None tensor([1.])

计算图如下:

最后,可以通过 torch.autograd.gradcheck 检查下自定义的 CustomMul 工作是否正常:

print(torch.autograd.gradcheck(func=CustomMul.apply, inputs=(w, x)))

输出结果为:True

未经允许不得转载:一亩三分地 » PyTorch 自定义算子
评论 (0)

4 + 3 =