import torch import numpy as np # 1. 张量和 numpy 的转换 def test01(): data_tensor = torch.tensor([2, 3, 4]) # 将张量转换为 numpy 数组 data_numpy = data_tensor.numpy() # data_tensor 和 data_numpy 共享内存 # 修改其中任意一个变量, 另外一个也会发生改变 data_tensor[0] = 100 print(data_numpy) data_numpy[0] = 200 print(data_tensor) # 2. 写时拷贝 def test02(): data_numpy = np.array([2, 3, 4]) data_numpy[0] = 100 # data_tensor 与 data_numpy 共享内存 # 修改 data_numpy 会导致 data_tensor 发生改变 data_tensor = torch.from_numpy(data_numpy) print(data_tensor) # 当 data_tensor 修改时, 发生写时拷贝(延迟拷贝) # data_numpy 并不会改变 data_tensor[0] = 200 print(data_numpy) # 3. 标量张量和数字的转换 def test03(): # 当张量只包含一个元素时, 可以通过 item 函数提取出该值 data = torch.tensor([30,]) print(data.item()) data = torch.tensor(30) print(data.item()) if __name__ == '__main__': test02()
PyTorch Tensor 类型转换
未经允许不得转载:一亩三分地 » PyTorch Tensor 类型转换