PyTorch Tensor 类型转换

import torch
import numpy as np


# 1. 张量和 numpy 的转换
def test01():

    data_tensor = torch.tensor([2, 3, 4])
    # 将张量转换为 numpy 数组
    data_numpy = data_tensor.numpy()

    # data_tensor 和 data_numpy 共享内存
    # 修改其中任意一个变量, 另外一个也会发生改变
    data_tensor[0] = 100
    print(data_numpy)

    data_numpy[0] = 200
    print(data_tensor)


# 2. 写时拷贝
def test02():

    data_numpy = np.array([2, 3, 4])
    data_numpy[0] = 100

    # data_tensor 与 data_numpy 共享内存
    # 修改 data_numpy 会导致 data_tensor 发生改变
    data_tensor = torch.from_numpy(data_numpy)
    print(data_tensor)

    # 当 data_tensor 修改时, 发生写时拷贝(延迟拷贝)
    # data_numpy 并不会改变
    data_tensor[0] = 200
    print(data_numpy)


# 3. 标量张量和数字的转换
def test03():

    # 当张量只包含一个元素时, 可以通过 item 函数提取出该值
    data = torch.tensor([30,])
    print(data.item())

    data = torch.tensor(30)
    print(data.item())


if __name__ == '__main__':
    test02()
未经允许不得转载:一亩三分地 » PyTorch Tensor 类型转换
评论 (0)

3 + 8 =