PIL Image与tensor在PyTorch图像预处理时的转换
前言:在使用深度学习框架PyTorch预处理图像数据时,你可能和我一样遇到过各种各样的问题,网上虽然总能找到类似的问题,但不同文章的代码环境不同,也不一定能直接解决自己的问题。这时,就需要就自身所出bug了解问题本身涉及的大致原理,依据报错的具体位置( 要完整的看完bug信息,不要只看最后报错信息而不看中间调用过程 )才能更快的精准解决自己的问题
一、原理概述
PIL(Python Imaging Library)是Python中最基础的图像处理库,而使用PyTorch将
原始输入图像
预处理为
神经网络的输入
,经常需要用到
三种格式PIL Image、Numpy和Tensor
,其中预处理包括但不限于「
图像裁剪
」,「
图像旋转
」和「
图像数据归一化
」等。而对图像的多种处理在code中可以打包到一起执行,一般用
transforms.Compose(transforms)
将多个transform组合起来使用。如下所示
from torchvision import transforms
transform = transforms.Compose([
# 重置大小
transforms.Resize(255),
transforms.CenterCrop(224),
# 随机旋转图片
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# 正则化(降低模型复杂度)
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
其中,不同图像处理方法要求输入的图像格式不同,比如
Resize()
和
RandomHorizontalFlip()
等方法要求输入的图像为
PIL Image
,而正则化操作
Normalize()
处理的是
tensor
格式的图像数据。因此,针对不同操作的数据格式要求,我们需要
在不同操作之前将输入图像数据的格式化成所要求的格式
,有了这些概念了解,面对可能出现的bug,我们才能游刃有余的精准处理。
二、PIL Image与tensor的转换
2.1 tensor转换为PIL Image
from torchvision.transforms
PIL_img = transforms.ToPILImage()(tensor_img)
2.2 PIL Image转换为tensor
一般放在
transforms.Compose(transforms)
组合中正则化操作的前面即可
transforms.ToTensor()
2.3 Numpy转换为PIL Image
from PIL import Image
PIL_img = Image.fromarray(array)
三、可能遇到的问题
3.1 img should be PIL Image. Got <class ‘torch.Tensor’>
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
这个问题,网上大部分博文甚至stackoverflow上说的都是
transforms.Compose(transforms)
组合中的顺序问题,但按照这些说法修改顺序后我仍一直未解决问题。后来了解了原理并结合自己实际bug出现的位置,才最终解决。
如下图所示,我的bug出现在红框中的句柄中,而与大多数博文不同的是,我是先对图像做灰度处理,然后再做剪裁和旋转的操作,因此
transforms.Compose(transforms)
组合操作在这行代码之后,自然怎么改顺序都无动于衷。所以从bug的位置可知此问题与组合操作顺序无关,但从最后的类型错误中可知此行代码传进去的observation类型期望是PIL,但实际是tensor,因此只要在此之前进行两者格式的转换即可解决bug
解决方案从
transform = T.Grayscale()
img = transform(img)
变为
transform = T.Grayscale()
img = T.ToPILImage()(img)
img = transform(img)
3.1 tensor should be a torch tensor. Got <class ‘PIL.Image.Image’>.
TypeError: tensor should be a torch tensor. Got <class 'PIL.Image.Image'>.
肯定是需要
tensor
的图像操作传入的是
PIL
,因此在合适的位置前将
PIL
转换为
tensor
即可
解决方法从
transform = transforms.Compose([
transforms.Resize(255),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
到
transform = transforms.Compose([