PyTorch 自定义算子(operator)

PyTorch 算子是 PyTorch 框架的核心组成部分,用于构建神经网络模型、执行计算任务以及进行张量操作。它们提供了丰富的功能、高效的性能和灵活的开发方式,是深度学习开发中不可或缺的工具。

什么是 PyTorch 算子?

在 PyTorch 中,算子(operator) 指的是用于执行张量(Tensor)计算的基本操作,它们通常是 PyTorch 内部定义的函数或运算符。PyTorch 提供了一系列高效的算子,用于张量运算、自动微分、神经网络计算等。例如:

  • 数学运算torch.addtorch.subtorch.multorch.div
  • 矩阵运算torch.matmultorch.transpose
  • 激活函数torch.relu)、torch.sigmoid
  • 张量变换torch.reshapetorch.cat

当然,我们也可以自定义算子。首先创建一个自定义算子类并继承 autograd.Function,需要实现两个静态的方法 forward 和 backward。应用该算子时,调用 apply 方法,不要直接调用 forward 方法。

forward 静态方法中第一个参数为 ctx,它可以理解 Function 对象本身,其方法 save_for_backward 用于在前向计算时,将反向计算用到的中间结果进行缓存。在 backward 静态方法中,也有 ctx 方法,我们可以通过它的 saved_tensors 来取出前向计算时缓存的中间结果。

前向计算方法 forward 的其他参数为该 op 需要传递进来的参数,例如:该 op 用作乘法运算,则其他的两个参数就是需要相乘的两个张量。方法 backward 的另一个参数为 grad_outputs,我们知道神经网络是通过链式求导的方法来计算参数梯度,这里的 grad_outputs 为上一步输出的梯度。backward 方法最终要返回相应张量的梯度,例如:forward 函数中按顺序传入了 w,x 则在 backward 中计算完 w 和 x 梯度之后,以此返回即可。

示例代码:

import torch
from torchviz import make_dot


class CustomMul(torch.autograd.Function):

    @staticmethod
    def forward(ctx, w, x):
        # print('CustomMul forward')
        ctx.save_for_backward(w, x)
        return w * x

    @staticmethod
    def backward(ctx, grad_outputs):
        # print('CustomMul backward')
        # print('CustomMul grad_outputs:', grad_outputs)
        w = ctx.saved_tensors[0]
        x = ctx.saved_tensors[1]

        w_grad = grad_outputs * x
        x_grad = grad_outputs * w

        return w_grad, x_grad


class CustomAdd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, b):
        # print('CustomAdd forward')
        return x + b

    @staticmethod
    def backward(ctx, grad_outputs):
        # print('CustomAdd backward')
        # print('CustomAdd grad_outputs:', grad_outputs)

        b_grad = grad_outputs * 1
        x_grad = grad_outputs * 1

        return x_grad, b_grad


def mul(w, input):
    return CustomMul.apply(w, input)

def add(b, input):
    return CustomAdd.apply(b, input)


if __name__ == '__main__':

    w = torch.tensor([[2.0, 3.0]], requires_grad=True)
    b = torch.tensor([3.0], requires_grad=True)
    x = torch.tensor([[4.0, 5.0]])

    outputs = mul(w, x)
    outputs = add(b, torch.sum(outputs))

    dot = make_dot(outputs, show_attrs=True)
    dot.render('temp', format='png', cleanup=False)

    outputs.backward()
    print(w.grad, x.grad, b.grad)

程序执行结果:

tensor([[4., 5.]], dtype=torch.float64) None tensor([1.])

计算图如下:

最后,可以通过 torch.autograd.gradcheck 检查下自定义的 CustomMul 工作是否正常:

print(torch.autograd.gradcheck(func=CustomMul.apply, inputs=(w, x)))

输出结果为:True

未经允许不得转载:一亩三分地 » PyTorch 自定义算子(operator)
评论 (0)

6 + 7 =