轻松学Pytorch-使用torchvision的transforms实现图像预处理
Transforms包介绍
Pytorch中的图像预处理都跟transforms这个工具包有关系,它是一个常用的图像变换工具包,主要支持方式有两中:
Compose方式 ,支持链式处理,可以集合多个transforms的方法或者类。Compose方式的例子如下:
transforms.Compose([
transforms.CenterCrop(10), // 剪切为10x10大小
transforms.ToTensor(), // 像素值转换为0~1
])
Scriptable transforms方式 ,通过即时运行的脚本方式实现图像变换。例子图示如下:
transforms = torch.nn.Sequential(
transforms.CenterCrop(10),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
scripted_transforms = torch.jit.script(transforms)
当以script脚本形式运行时候,transfroms必须使用torch.nn.Module继承类实现链式处理流程的定义与组合。
官方说明上述两种变换方式均支持PIL图像对象与Tensor对象,输入的图像格式必须为以下:
(C、H、W) 一张图像变换
或者
(B、C、H、W) 多张图像变换
其中C表示图像通道数、H表示图像高度、W表示图像宽度,B表示batch数目
常用图像转换类功能列表
常见的torchvision.transforms的类与功能如下:
torchvision.transforms.CenterCrop // 中心剪切
torchvision.transforms.ColorJitter // 颜色颜色,支持亮度、饱和度、色泽
torchvision.transforms.FiveCrop // 5次剪切,把图像剪切为四个部分+中间部分
torchvision.transforms.Grayscale // 灰度转换
torchvision.transforms.Pad // 填充
torchvision.transforms.RandomAffine // 随机几何变换,支持错切、平移、旋转等
torchvision.transforms.RandomApply // 对多个transfrom的随机应用
torchvision.transforms.RandomCrop // 随机剪切
torchvision.transforms.RandomGrayscale // 随机灰度
torchvision.transforms.RandomHorizontalFlip // 随机水平翻转
torchvision.transforms.RandomPerspective // 随机透视变换,参数可设置
torchvision.transforms.RandomResizedCrop // 随机剪切+尺度
torchvision.transforms.RandomRotation // 随机旋转
torchvision.transforms.RandomVerticalFlip // 随机垂直翻转
torchvision.transforms.GaussianBlur // 随机高斯模糊,模糊程度随机
torchvision.transforms.LinearTransformation // 图像随机线性变换
torchvision.transforms.Normalize // 归一化
torchvision.transforms.RandomErasing // 随机擦除
torchvision.transforms.ToPILImage // 转换为PIL图像输出
此外还这支持单独的功能函数相关的方法,通过torchvision.transforms.functional实现支持。
scriptable方式的代码变换演示
中心剪切+归一化,代码如下:
transforms = torch.nn.Sequential(
tf.CenterCrop(400),
tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
scripted_transforms = torch.jit.script(transforms)

随机剪切+灰度化
torch.manual_seed(17)
transforms = torch.nn.Sequential(
tf.RandomCrop(300),
tf.Grayscale()
scripted_transforms = torch.jit.script(transforms)
其中300表示剪切输出的图像大小为300x300,运行结果:

随机旋转 + 随机灰度,代码如下:
torch.manual_seed(17)
transforms = torch.nn.Sequential(
tf.RandomRotation((35, 135), resample=0),
tf.RandomGrayscale()
scripted_transforms = torch.jit.script(transforms)
其中35表示最小旋转角度、135表示最大旋转角度,此范围内随机旋转。
运行结果如下:

高斯模糊操作:
torch.manual_seed(17)
transforms = torch.nn.Sequential(
tf.GaussianBlur(kernel_size=15, sigma=(5.0, 15.0))
scripted_transforms = torch.jit.script(transforms)
运行结果如下:

随机翻转 + 归一化
torch.manual_seed(17)
transforms = torch.nn.Sequential(
tf.RandomHorizontalFlip(),
tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
scripted_transforms = torch.jit.script(transforms)
运行结果如下:

运行上述的测试代码,我没有使用PIL库,而是使用OpenCV完成了图像读取与处理显示,代码如下:
import torch
import cv2 as cv
import numpy as np
import torchvision.transforms as tf
torch.manual_seed(17)
transforms = torch.nn.Sequential(
tf.Grayscale()
scripted_transforms = torch.jit.script(transforms)
image = cv.imread("D:/images/1024.png")
cv.imshow("input", image)