张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如: 在残差网络、注意力机制中都使用到了张量拼接。
- torch.cat 函数的使用
- torch.stack 函数的使用
1. torch.cat 函数的使用
torch.cat 函数主要用于根据指定的维度将两个张量拼接到一起,用于拼接的两个张量的维度一般相同。
import torch def test(): data1 = torch.randint(0, 10, [3, 5, 4]) data2 = torch.randint(0, 10, [3, 5, 4]) print(data1) print(data2) print('-' * 50) # 1. 按0维度拼接 new_data = torch.cat([data1, data2], dim=0) print(new_data) print('-' * 50) # 2. 按1维度拼接 new_data = torch.cat([data1, data2], dim=1) print(new_data) print('-' * 50) # 3. 按2维度拼接 new_data = torch.cat([data1, data2], dim=2) print(new_data) if __name__ == '__main__': test()
输出结果:
tensor([[[1, 0, 3, 9], [4, 1, 7, 3], [5, 6, 5, 1], [8, 8, 1, 3], [3, 6, 8, 9]], [[0, 0, 9, 3], [3, 9, 3, 5], [9, 2, 8, 6], [5, 3, 6, 9], [6, 2, 0, 4]], [[0, 7, 3, 0], [0, 2, 8, 2], [4, 9, 6, 8], [0, 7, 9, 9], [9, 9, 8, 1]]]) tensor([[[4, 6, 3, 0], [4, 4, 4, 3], [0, 2, 4, 6], [5, 6, 6, 7], [0, 7, 1, 6]], [[0, 2, 5, 8], [8, 2, 1, 8], [9, 4, 9, 7], [3, 6, 7, 8], [4, 8, 7, 0]], [[6, 0, 2, 5], [2, 4, 6, 3], [3, 7, 7, 0], [8, 6, 0, 0], [6, 7, 3, 8]]]) -------------------------------------------------- tensor([[[1, 0, 3, 9], [4, 1, 7, 3], [5, 6, 5, 1], [8, 8, 1, 3], [3, 6, 8, 9]], [[0, 0, 9, 3], [3, 9, 3, 5], [9, 2, 8, 6], [5, 3, 6, 9], [6, 2, 0, 4]], [[0, 7, 3, 0], [0, 2, 8, 2], [4, 9, 6, 8], [0, 7, 9, 9], [9, 9, 8, 1]], [[4, 6, 3, 0], [4, 4, 4, 3], [0, 2, 4, 6], [5, 6, 6, 7], [0, 7, 1, 6]], [[0, 2, 5, 8], [8, 2, 1, 8], [9, 4, 9, 7], [3, 6, 7, 8], [4, 8, 7, 0]], [[6, 0, 2, 5], [2, 4, 6, 3], [3, 7, 7, 0], [8, 6, 0, 0], [6, 7, 3, 8]]]) -------------------------------------------------- tensor([[[1, 0, 3, 9], [4, 1, 7, 3], [5, 6, 5, 1], [8, 8, 1, 3], [3, 6, 8, 9], [4, 6, 3, 0], [4, 4, 4, 3], [0, 2, 4, 6], [5, 6, 6, 7], [0, 7, 1, 6]], [[0, 0, 9, 3], [3, 9, 3, 5], [9, 2, 8, 6], [5, 3, 6, 9], [6, 2, 0, 4], [0, 2, 5, 8], [8, 2, 1, 8], [9, 4, 9, 7], [3, 6, 7, 8], [4, 8, 7, 0]], [[0, 7, 3, 0], [0, 2, 8, 2], [4, 9, 6, 8], [0, 7, 9, 9], [9, 9, 8, 1], [6, 0, 2, 5], [2, 4, 6, 3], [3, 7, 7, 0], [8, 6, 0, 0], [6, 7, 3, 8]]]) -------------------------------------------------- tensor([[[1, 0, 3, 9, 4, 6, 3, 0], [4, 1, 7, 3, 4, 4, 4, 3], [5, 6, 5, 1, 0, 2, 4, 6], [8, 8, 1, 3, 5, 6, 6, 7], [3, 6, 8, 9, 0, 7, 1, 6]], [[0, 0, 9, 3, 0, 2, 5, 8], [3, 9, 3, 5, 8, 2, 1, 8], [9, 2, 8, 6, 9, 4, 9, 7], [5, 3, 6, 9, 3, 6, 7, 8], [6, 2, 0, 4, 4, 8, 7, 0]], [[0, 7, 3, 0, 6, 0, 2, 5], [0, 2, 8, 2, 2, 4, 6, 3], [4, 9, 6, 8, 3, 7, 7, 0], [0, 7, 9, 9, 8, 6, 0, 0], [9, 9, 8, 1, 6, 7, 3, 8]]])
2. torch.stack 函数的使用
torch.cat 函数主要用于根据指定的维度将两个张量叠加到一起,用于拼接的两个张量的维度一般相同,其结果会使得数据增加一维。
import torch def test(): data1= torch.randint(0, 10, [2, 3]) data2= torch.randint(0, 10, [2, 3]) print(data1) print(data2) print('-' * 50) new_data = torch.stack([data1, data2], dim=0) print(new_data) print('-' * 50) new_data = torch.stack([data1, data2], dim=1) print(new_data) print('-' * 50) new_data = torch.stack([data1, data2], dim=2) print(new_data) if __name__ == '__main__': test()
输出结果:
tensor([[2, 9, 8], [9, 0, 1]]) tensor([[7, 3, 6], [1, 0, 3]]) -------------------------------------------------- tensor([[[2, 9, 8], [9, 0, 1]], [[7, 3, 6], [1, 0, 3]]]) -------------------------------------------------- tensor([[[2, 9, 8], [7, 3, 6]], [[9, 0, 1], [1, 0, 3]]]) -------------------------------------------------- tensor([[[2, 7], [9, 3], [8, 6]], [[9, 1], [0, 0], [1, 3]]])
至此,本篇文章结束。