我们在操作张量时,经常需要去进行获取或者修改操作,掌握张量的花式索引操作是必须的一项能力。
import torch data = torch.randint(0, 10, [4, 5]) print(data) print('-' * 50) # 1. 简单行、列索引 def test01(): print(data[0]) print(data[:, 0]) print('-' * 50) # 2. 列表索引 def test02(): # 返回 (0, 1)、(1, 2) 两个位置的元素 print(data[[0, 1], [1, 2]]) print('-' * 50) # 返回 0、1 行的 1、2 列共4个元素 print(data[[[0], [1]], [1, 2]]) # 3. 范围索引 def test03(): # 前3行的前2列数据 print(data[:3, :2]) # 第2行到最后的前2列数据 print(data[2:, :2]) # 4. 布尔索引 def test04(): # 第三列大于5的行数据 print(data[data[:, 2] > 5]) # 第二行大于5的列数据 print(data[:, data[1] > 5]) # 5. 多维索引 def test05(): data = torch.randint(0, 10, [3, 4, 5]) print(data) print('-' * 50) print(data[0, :, :]) print(data[:, 0, :]) print(data[:, :, 0]) if __name__ == '__main__': test05()