您的当前位置:首页正文

【深度学习】transforms图像预处理操作

2024-11-13 来源:个人技术集锦


前言

transforms这个函数大量出现于CV领域的深度学习,但是我一直没搞懂什么样的数据集应该使用什么样的操作。比如,这些图片是需要裁剪,还是翻转,还是怎么样的,并且它们操作完之后是什么样子的。这一期我们就好好实验一下。

有一点比较重要的是,在每次从torchvision.datasets 接口import 数据集时,都可以要使用transfrom操作(就连最简单的mnist数据集也需要),如果是自制数据集的话,也可以i使用IamgeFolder里面设置transform操作
简单一点就这样:(只有一个操作)

from torchvision import datasets,transforms
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

复杂一点的就这样,有Compose有多个操作:

from torchvision import datasets,transforms
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transforms
)
transforms=transforms.Compose(
    .................
)

注意Compose大写。
我们今天的任务就是好好研究一下里面的命令

一、transforms.compose() 图像预处理操作组合

这个是里面最基本的命令,把多个图像预处理操作放在一起了,不多解释了。

二、transforms.ToTensor() 转为张量

看了四五个项目,总结下来就是:其他的命令也许你可以不用,但是这个命令是必须的。因为需要把图像转为张量。
它让图像数组或image对象转为张量且归一化到(0,1):
我们实验一下:

from torchvision import transforms,datasets
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import matplotlib.pyplot as plt

train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train')
plt.imshow(train[0][0])
plt.show()
print(type(train[0][0]))

train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train',
                             transform=transforms.ToTensor())
print(train[0][0])
print(type(train[0][0]))

我从猫狗数据集中随便找了一张照片显示:

使用totensor之后:

三、transforms.Normalize()图像标准化操作

这个操作一般在transforms.ToTensor()的后面,具体操作就是对张量进行正态化,标准化(标准高斯分布)

from torchvision import transforms,datasets
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import matplotlib.pyplot as plt
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])#记得打【】 不然要报错
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train',
                             transform=transform)
print(train[0][0])
print(type(train[0][0]))

结果如下,可以看到转为了标准高斯分布:

四、transforms.CenterCrop() 由中心按指定大小切割

这个是从中心点开始裁剪的

from torchvision import transforms,datasets
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import matplotlib.pyplot as plt
# transform=transforms.Compose([
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
# ])#记得打【】 不然要报错
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train')
plt.imshow(train[0][0])
plt.show()
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train',
                             transform=transforms.CenterCrop(224))
plt.imshow(train[0][0])
plt.show()
print(train[0][0])
print(type(train[0][0]))

原本的图像:

五、transforms.RandomCrop() 中心点的位置随便选取,指定大小切割

from torchvision import transforms,datasets
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import matplotlib.pyplot as plt
# transform=transforms.Compose([
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
# ])#记得打【】 不然要报错
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train')
plt.imshow(train[0][0])
plt.show()
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train',
                             transform=transforms.RandomCrop(224))
plt.imshow(train[0][0])
plt.show()
print(train[0][0])
print(type(train[0][0]))

六、transforms.RandomResizedCrop() 随机长宽比裁剪为指定大小

from torchvision import transforms,datasets
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import matplotlib.pyplot as plt
# transform=transforms.Compose([
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
# ])#记得打【】 不然要报错
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train')
plt.imshow(train[0][0])
plt.show()
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train',
                             transform=transforms.RandomResizedCrop(224))
plt.imshow(train[0][0])
plt.show()
print(train[0][0])
print(type(train[0][0]))

七、transforms.RandomResizedCrop()随机水平翻转

from torchvision import transforms,datasets
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import matplotlib.pyplot as plt
# transform=transforms.Compose([
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
# ])#记得打【】 不然要报错
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train')
plt.imshow(train[0][0])
plt.show()
train = datasets.ImageFolder('E:/图像处理课题/cat_dog/train',
                             transform=transforms.RandomHorizontalFlip())
plt.imshow(train[0][0])
plt.show()
print(train[0][0])
print(type(train[0][0]))

显示全文