PyTorch 图像裁剪

torchvision 提供了一些用于图像裁剪的方法,这些方法也可以用于图像增强。主要介绍下:

  1. PIL 和 Tensor 转换
  2. 多种图像裁剪方法

1. PIL 和 Tensor 转换

处理图像之前,需要先读取图片。我们可以使用 torchvision.io.image 模块的 read_image 方法读取,但是该方法只能读取 jpeg 或者 png 图片。另外一种方法,我们可以使用 PIL.Image 的 open 方法来读取,它能够支持的图片类型更多一些,但是读取的图片并不是 Tensor 类型,而是 PIL 图像,我们就需要调用 torchvision 提供的一些方法来将其转换成 Tensor,以便于后续的裁剪处理。

这里需要注意一点,matplotlib.pyplot.imshow 方法只能显示 PIL 图片,所以还得能够将处理后的 Tensor 图像转换回 PIL 图像。

示例代码:

from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.transforms import PILToTensor
from torchvision.transforms import ToPILImage
import torchvision.transforms.functional as TF
import numpy as np

# PIL 和 Tensor 转换
def test01():

    img = Image.open('demo.webp')
    print(img)

    # 将 PIL 或者 ndarray 的 (H x W x C) [0, 255] 转换为 (C x H x W) [0.0, 1.0]
    img1 = ToTensor()(img)
    img2 = TF.to_tensor(img)
    print(img1.shape, torch.equal(img1, img2))

    # 将 PIL 的 (H x W x C) [0, 255] 转换为 (C x H x W) [0, 255]
    img3 = PILToTensor()(img)
    img4 = TF.pil_to_tensor(img)
    print(img2.shape, torch.equal(img3, img4))

    # 将 Tensor (C x H x W) 或者 ndarray (H x W x C) 转换为 PIL,并保留取值范围
    img5 = ToPILImage()(img1)
    img6 = TF.to_pil_image(img2)
    print(img5)
    print(img6)

if __name__ == '__main__':
    test01()

程序执行结果:

<PIL.WebPImagePlugin.WebPImageFile image mode=RGB size=625x625 at 0x7FCD182B7F90>
torch.Size([3, 625, 625]) True
torch.Size([3, 625, 625]) True
<PIL.Image.Image image mode=RGB size=625x625 at 0x7FCD1837A5D0>
<PIL.Image.Image image mode=RGB size=625x625 at 0x7FCD184DBD50>

2. 多种图像裁剪方法

图像裁剪就是从图像中扣出一部分子图作为新的图像,不同的方法只是扣的方式不同,目的都是相同的。

# 从中心位置开始裁剪出 size 大小的图像
from torchvision.transforms import CenterCrop

# 从图像的4个角和中心位置裁剪出 size 大小的5个图像
from torchvision.transforms import FiveCrop

# 随机位置裁剪 size 大小的图像
from torchvision.transforms import RandomCrop

# 随机选择原始图像 scale 范围的面积,随机选择长宽比 ratio,最终 resize 到 size 大小
from torchvision.transforms import RandomResizedCrop

# 随机裁剪4个角以及中心位置共5个图像,再加上5个图的随机翻转共10个图像
from torchvision.transforms import TenCrop

# 裁剪指定区间的图像
from torchvision.transforms.functional import crop

示例代码:

from PIL import Image
from torchvision.transforms import ToPILImage
from torchvision.transforms import CenterCrop
from torchvision.transforms import ToTensor
from torchvision.transforms import FiveCrop
from torchvision.transforms import RandomCrop
from torchvision.transforms import RandomResizedCrop
from torchvision.transforms import TenCrop
import torchvision.transforms.functional as TF
from torchvision.io import image
import matplotlib.pyplot as plt
import torch
import math

def show_image(images):

    # 图像绘制几行几列
    img_num = len(images)
    row_num = math.ceil(img_num / 3)
    col_num = 3 if img_num > 3 else img_num

    for idx, img in enumerate(images):
        # Tensor 转换为 PIL
        pil = TF.to_pil_image(img)
        plt.subplot(row_num, col_num, idx + 1)
        plt.imshow(pil)
        plt.xlim(0, img.size(2))
        plt.ylim(img.size(1), 0)
        plt.title('原图' if not idx else '裁剪')

    plt.tight_layout()
    plt.show()

CenterCrop

def test01():

    origin = image.read_image('demo.png')
    # 对原图 img0 以中心位置开始裁剪 size 大小
    crop1 = CenterCrop(size=(500, 300))(origin)
    crop2 = TF.center_crop(origin, output_size=(500, 300))
    show_image([origin, crop1, crop2])

程序执行结果:

FiveCrop

def test02():

    origin = image.read_image('demo.png')
    crops = FiveCrop(size=(400, 400))(origin)
    show_image((origin,) + crops)

程序执行结果:

RandomCrop

def test03():

    origin = image.read_image('demo.png')
    crop = RandomCrop(size=(400, 400))(origin)
    show_image([origin, crop])

程序执行结果:

RandomResizedCrop

def test04():

    origin = image.read_image('demo.png')
    # scale 表示随机从原图中选择 0.8~1.0 面积
    # ratio 表示随机从上界和下界选择宽高比
    crop = RandomResizedCrop(size=(500, 500),
                             scale=(0.8, 1.0),
                             ratio=(3.0/ 4.0, 4.0/3.0))(origin)
    show_image([origin, crop])

程序执行结果:

TenCrop

def test05():

    origin = image.read_image('demo.png')
    crops = TenCrop(size=(400, 400))(origin)
    # crops = TF.ten_crop(origin, size=(400, 400))
    show_image((origin,) + crops)

程序执行结果:

TF.crop

def test06():

    origin = image.read_image('demo.png')
    crop = TF.crop(origin, top=100, left=200, height=300, width=200)
    show_image([origin, crop])

程序执行结果:

未经允许不得转载:一亩三分地 » PyTorch 图像裁剪
评论 (0)

7 + 9 =