PyTorch Tensor 拼接操作

张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如: 在残差网络、注意力机制中都使用到了张量拼接。

  1. torch.cat 函数的使用
  2. torch.stack 函数的使用

1. torch.cat 函数的使用

torch.cat 函数主要用于根据指定的维度将两个张量拼接到一起,用于拼接的两个张量的维度一般相同。

import torch


def test():

    data1 = torch.randint(0, 10, [3, 5, 4])
    data2 = torch.randint(0, 10, [3, 5, 4])

    print(data1)
    print(data2)
    print('-' * 50)

    # 1. 按0维度拼接
    new_data = torch.cat([data1, data2], dim=0)
    print(new_data)
    print('-' * 50)

    # 2. 按1维度拼接
    new_data = torch.cat([data1, data2], dim=1)
    print(new_data)
    print('-' * 50)

    # 3. 按2维度拼接
    new_data = torch.cat([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

输出结果:

tensor([[[1, 0, 3, 9],
         [4, 1, 7, 3],
         [5, 6, 5, 1],
         [8, 8, 1, 3],
         [3, 6, 8, 9]],

        [[0, 0, 9, 3],
         [3, 9, 3, 5],
         [9, 2, 8, 6],
         [5, 3, 6, 9],
         [6, 2, 0, 4]],

        [[0, 7, 3, 0],
         [0, 2, 8, 2],
         [4, 9, 6, 8],
         [0, 7, 9, 9],
         [9, 9, 8, 1]]])
tensor([[[4, 6, 3, 0],
         [4, 4, 4, 3],
         [0, 2, 4, 6],
         [5, 6, 6, 7],
         [0, 7, 1, 6]],

        [[0, 2, 5, 8],
         [8, 2, 1, 8],
         [9, 4, 9, 7],
         [3, 6, 7, 8],
         [4, 8, 7, 0]],

        [[6, 0, 2, 5],
         [2, 4, 6, 3],
         [3, 7, 7, 0],
         [8, 6, 0, 0],
         [6, 7, 3, 8]]])
--------------------------------------------------
tensor([[[1, 0, 3, 9],
         [4, 1, 7, 3],
         [5, 6, 5, 1],
         [8, 8, 1, 3],
         [3, 6, 8, 9]],

        [[0, 0, 9, 3],
         [3, 9, 3, 5],
         [9, 2, 8, 6],
         [5, 3, 6, 9],
         [6, 2, 0, 4]],

        [[0, 7, 3, 0],
         [0, 2, 8, 2],
         [4, 9, 6, 8],
         [0, 7, 9, 9],
         [9, 9, 8, 1]],

        [[4, 6, 3, 0],
         [4, 4, 4, 3],
         [0, 2, 4, 6],
         [5, 6, 6, 7],
         [0, 7, 1, 6]],

        [[0, 2, 5, 8],
         [8, 2, 1, 8],
         [9, 4, 9, 7],
         [3, 6, 7, 8],
         [4, 8, 7, 0]],

        [[6, 0, 2, 5],
         [2, 4, 6, 3],
         [3, 7, 7, 0],
         [8, 6, 0, 0],
         [6, 7, 3, 8]]])
--------------------------------------------------
tensor([[[1, 0, 3, 9],
         [4, 1, 7, 3],
         [5, 6, 5, 1],
         [8, 8, 1, 3],
         [3, 6, 8, 9],
         [4, 6, 3, 0],
         [4, 4, 4, 3],
         [0, 2, 4, 6],
         [5, 6, 6, 7],
         [0, 7, 1, 6]],

        [[0, 0, 9, 3],
         [3, 9, 3, 5],
         [9, 2, 8, 6],
         [5, 3, 6, 9],
         [6, 2, 0, 4],
         [0, 2, 5, 8],
         [8, 2, 1, 8],
         [9, 4, 9, 7],
         [3, 6, 7, 8],
         [4, 8, 7, 0]],

        [[0, 7, 3, 0],
         [0, 2, 8, 2],
         [4, 9, 6, 8],
         [0, 7, 9, 9],
         [9, 9, 8, 1],
         [6, 0, 2, 5],
         [2, 4, 6, 3],
         [3, 7, 7, 0],
         [8, 6, 0, 0],
         [6, 7, 3, 8]]])
--------------------------------------------------
tensor([[[1, 0, 3, 9, 4, 6, 3, 0],
         [4, 1, 7, 3, 4, 4, 4, 3],
         [5, 6, 5, 1, 0, 2, 4, 6],
         [8, 8, 1, 3, 5, 6, 6, 7],
         [3, 6, 8, 9, 0, 7, 1, 6]],

        [[0, 0, 9, 3, 0, 2, 5, 8],
         [3, 9, 3, 5, 8, 2, 1, 8],
         [9, 2, 8, 6, 9, 4, 9, 7],
         [5, 3, 6, 9, 3, 6, 7, 8],
         [6, 2, 0, 4, 4, 8, 7, 0]],

        [[0, 7, 3, 0, 6, 0, 2, 5],
         [0, 2, 8, 2, 2, 4, 6, 3],
         [4, 9, 6, 8, 3, 7, 7, 0],
         [0, 7, 9, 9, 8, 6, 0, 0],
         [9, 9, 8, 1, 6, 7, 3, 8]]])

2. torch.stack 函数的使用

torch.cat 函数主要用于根据指定的维度将两个张量叠加到一起,用于拼接的两个张量的维度一般相同,其结果会使得数据增加一维。

import torch


def test():

    data1= torch.randint(0, 10, [2, 3])
    data2= torch.randint(0, 10, [2, 3])
    print(data1)
    print(data2)
    print('-' * 50)

    new_data = torch.stack([data1, data2], dim=0)
    print(new_data)
    print('-' * 50)

    new_data = torch.stack([data1, data2], dim=1)
    print(new_data)
    print('-' * 50)

    new_data = torch.stack([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

输出结果:

tensor([[2, 9, 8],
        [9, 0, 1]])
tensor([[7, 3, 6],
        [1, 0, 3]])
--------------------------------------------------
tensor([[[2, 9, 8],
         [9, 0, 1]],

        [[7, 3, 6],
         [1, 0, 3]]])
--------------------------------------------------
tensor([[[2, 9, 8],
         [7, 3, 6]],

        [[9, 0, 1],
         [1, 0, 3]]])
--------------------------------------------------
tensor([[[2, 7],
         [9, 3],
         [8, 6]],

        [[9, 1],
         [0, 0],
         [1, 3]]])

至此,本篇文章结束。

未经允许不得转载:一亩三分地 » PyTorch Tensor 拼接操作
评论 (0)

7 + 2 =