我们在使用 PyTorch 建网络模型时,网络层与层之间很多都是使用不同的 shape 进行运算,我们需要掌握对张量形状的操作,以便能够更好处理网络各层之间的数据连接。
- reshape 函数
- transpose 和 permute 函数
- view 和 contigous 函数
- squeeze 和 unsqueeze 函数
1. reshape 函数的用法
reshape 函数可以在保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状。
import torch import numpy as np def test(): data = torch.tensor([[10, 20, 30], [40, 50, 60]]) # 1. 使用 shape 属性或者 size 方法都可以获得张量的形状 print(data.shape, data.shape[0], data.shape[1]) print(data.size(), data.size(0), data.size(1)) # 2. 使用 reshape 函数修改张量形状 new_data = data.reshape(1, 6) print(new_data.shape) if __name__ == '__main__': test()
输出结果:
torch.Size([2, 3]) 2 3 torch.Size([2, 3]) 2 3 torch.Size([1, 6])
2. transpose 和 permute 函数的使用
transpose 函数可以实现交换张量形状的指定维度, 例如: 一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4 进行交换, 将张量的形状变为 (2, 4, 3),permute 函数可以一次交换更多的维度。
import torch import numpy as np def test(): data = torch.tensor(np.random.randint(0, 10, [3, 4, 5])) print('data shape:', data.size()) # 1. 交换1和2维度 new_data = torch.transpose(data, 1, 2) print('data shape:', new_data.size()) # 2. 将 data 的形状修改为 (4, 5, 3) new_data = torch.transpose(data, 0, 1) new_data = torch.transpose(new_data, 1, 2) print('new_data shape:', new_data.size()) # 3. 使用 permute 函数将形状修改为 (4, 5, 3) new_data = torch.permute(data, [1, 2, 0]) print('new_data shape:', new_data.size()) if __name__ == '__main__': test()
输出结果:
data shape: torch.Size([3, 4, 5]) data shape: torch.Size([3, 5, 4]) new_data shape: torch.Size([4, 5, 3]) new_data shape: torch.Size([4, 5, 3])
3. view 和 contigous 函数的用法
view 函数也可以用于修改张量的形状,但是其用法比较局限,只能用于存储在整块内存中的张量。在 PyTorch 中,有些张量是由不同的数据块组成的,它们并没有存储在整块的内存中,view 函数无法对这样的张量进行变形处理,例如: 一个张量经过了 transpose 或者 permute 函数的处理之后,就无法使用 view 函数进行形状操作。
import torch import numpy as np def test(): data = torch.tensor([[10, 20, 30], [40, 50, 60]]) print('data shape:', data.size()) # 1. 使用 view 函数修改形状 new_data = data.view(3, 2) print('new_data shape:', new_data.shape) # 2. 判断张量是否使用整块内存 print('data:', data.is_contiguous()) # True # 3. 使用 transpose 函数修改形状 new_data = torch.transpose(data, 0, 1) print('new_data:', new_data.is_contiguous()) # False # new_data = new_data.view(2, 3) # RuntimeError # 需要先使用 contiguous 函数转换为整块内存的张量,再使用 view 函数 print(new_data.contiguous().is_contiguous()) new_data = new_data.contiguous().view(2, 3) print('new_data shape:', new_data.shape) if __name__ == '__main__': test()
输出结果:
data shape: torch.Size([2, 3]) new_data shape: torch.Size([3, 2]) data: True new_data: False True new_data shape: torch.Size([2, 3])
4. squeeze 和 unsqueeze 函数的用法
去掉张量中数值为1的维度,如果张量的形状为: (A, 1, B, C, 1, D),那么输出张量的形状为: (A, B, C, D)。
当指定 dim 参数时,维度压缩操作只会在指定的维度上进行。
unsqueeze 是 squeeze 函数的反向操作, 可以用于增加维度。
注意:
1. 新张量与原张量共享内存,其中的一个发生变化,另外一个张量也会发生改变。
2. 如果一个张量只有1个维度,那么它不会受到上述方法的影响。
import torch import numpy as np def test(): data = torch.tensor(np.random.randint(0, 10, [1, 3, 1, 5])) print('data shape:', data.size()) # 1. 去掉值为1的维度 new_data = data.squeeze() print('new_data shape:', new_data.size()) # torch.Size([3, 5]) # 2. 去掉指定位置为1的维度,注意: 如果指定位置不是1则不删除 new_data = data.squeeze(2) print('new_data shape:', new_data.size()) # torch.Size([3, 5]) # 3. 在2维度增加一个维度 new_data = data.unsqueeze(-1) print('new_data shape:', new_data.size()) # torch.Size([3, 1, 5, 1]) if __name__ == '__main__': test()
输出结果:
data shape: torch.Size([1, 3, 1, 5]) new_data shape: torch.Size([3, 5]) new_data shape: torch.Size([1, 3, 5]) new_data shape: torch.Size([1, 3, 1, 5, 1])
至此,本篇文章结束。