pytorch中的transform用法

ops/2024/11/18 20:06:59/

在 PyTorch 中,transform 主要用于数据预处理和数据增强,尤其在计算机视觉任务中,通过 torchvision.transforms 模块进行图像的变换。transforms 可以对图像进行一系列操作,如裁剪、旋转、缩放、归一化等,以增强数据集的多样性,并提高模型的泛化能力。

1. torchvision.transforms 模块概述

torchvision.transforms 是 PyTorch 提供的一个图像转换工具,它包含一系列的变换操作。常见的转换操作包括:

  • 图像大小调整(Resize)
  • 裁剪(Crop)
  • 图像翻转(Flip)
  • 颜色调整(Color Jitter)
  • 图像归一化(Normalization)
  • 转换为张量(ToTensor)

2. 常用的 transforms 操作

python">from torchvision import transforms
1) transforms.ToTensor()

将图像转换为 PyTorch 张量(Tensor),并且自动将图像的像素值缩放到 [0, 1] 的范围内。

python">transform = transforms.ToTensor()
image_tensor = transform(image)
2) transforms.Resize()

调整图像的大小,可以指定一个单一的大小或宽度/高度。

python">transform = transforms.Resize((224, 224))  # 调整为 224x224 的尺寸
image_resized = transform(image)
3) transforms.CenterCrop()transforms.RandomCrop()

CenterCrop 会从图像的中心裁剪出指定大小的区域;RandomCrop 会随机裁剪出一个指定大小的区域。

python">transform = transforms.CenterCrop(224)  # 从中心裁剪出 224x224 的区域
image_cropped = transform(image)# 或者使用随机裁剪
transform = transforms.RandomCrop(224)
image_random_cropped = transform(image)
4) transforms.RandomHorizontalFlip()transforms.RandomVerticalFlip()

进行水平或垂直的随机翻转。

python">transform = transforms.RandomHorizontalFlip(p=0.5)  # 50% 的概率进行水平翻转
image_flipped = transform(image)
5) transforms.Normalize()

对图像的每个通道进行归一化。通常用来调整图像的颜色通道,使其符合模型训练时的要求。

python">transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image_normalized = transform(image_tensor)  # 对每个通道进行归一化
6) transforms.ColorJitter()

随机调整图像的亮度、对比度、饱和度和色相。适用于增强数据集的多样性。

python">transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
image_jittered = transform(image)
7) transforms.RandomRotation()

对图像进行随机旋转。

python">transform = transforms.RandomRotation(30)  # 随机旋转 -30 到 30 度之间
image_rotated = transform(image)

3. 多种 transforms 组合使用

通常,我们会将多个变换操作组合成一个 Compose,使得一个图像依次经过多个变换步骤。

python">transform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])image_transformed = transform(image)

上面的代码会将图像:

  1. 调整为 256x256
  2. 随机裁剪为 224x224
  3. 进行水平翻转
  4. 转换为张量
  5. 归一化图像

4. 结合 Dataset 使用 transforms

通常,我们会将 transformstorch.utils.data.Datasettorch.utils.data.DataLoader 结合使用,用于训练过程中的数据预处理。

python">from torchvision import datasets
from torch.utils.data import DataLoadertransform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])train_dataset = datasets.ImageFolder(root='path_to_train_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

在上面的代码中,ImageFolder 是一个 PyTorch 提供的通用图像数据集类,用于加载目录结构为类标签的图像数据。transform 用于对数据集中的每个图像进行预处理。

5. 自定义 transform

如果 torchvision.transforms 中的预定义操作不能满足需求,我们还可以自定义一个转换类。例如,如果你想为每张图片添加噪声:

python">from PIL import Image
import numpy as npclass AddGaussianNoise(object):def __init__(self, mean=0., std=1.):self.mean = meanself.std = stddef __call__(self, image):image = np.array(image)noise = np.random.normal(self.mean, self.std, image.shape)noisy_image = image + noisenoisy_image = np.clip(noisy_image, 0, 255)  # 保证像素值在 [0, 255] 范围内return Image.fromarray(noisy_image.astype(np.uint8))# 使用自定义转换
transform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),AddGaussianNoise(mean=0, std=0.1),  # 添加高斯噪声transforms.ToTensor(),
])image = Image.open('path_to_image.jpg')
transformed_image = transform(image)

总结

  • transforms 是 PyTorch 中处理图像数据的一组强大工具,适用于图像预处理和数据增强。
  • 通过 transforms.Compose() 可以组合多个转换操作。
  • ToTensor()Resize()RandomCrop()Normalize() 等是常用的转换。
  • 通过 DataLoader 可以高效地加载批量数据,并在训练过程中对每个样本应用转换。

http://www.ppmy.cn/ops/134780.html

相关文章

《InsCode AI IDE:编程新时代的引领者》

《InsCode AI IDE:编程新时代的引领者》 一、InsCode AI IDE 的诞生与亮相二、独特功能与优势(一)智能编程体验(二)多语言支持与功能迭代 三、实际应用与案例(一)游戏开发案例(二&am…

安全见闻1-5

涵盖了编程语言、软件程序类型、操作系统、网络通讯、硬件设备、web前后端、脚本语言、病毒种类、服务器程序、人工智能等基本知识,有助于全面了解计算机科学和网络技术的各个方面。 安全见闻1 1.编程语言简要概述 C语言:面向过程,适用于系统…

Ceph 中PG与PGP的概述

在Ceph分布式存储系统中,PG(Placement Group)和PGP(Placement Group for Placement purpose)是两个至关重要的概念,它们共同决定了数据在集群中的分布和复制方式。以下是关于Ceph中PG和PGP关系的详细解释&a…

无人机:科技改变生活的神奇力量

无人机,作为一种高科技产品,已经在我们的生活中发挥着越来越重要的作用。从军事侦察到民用拍摄,从农业监测到物流配送,无人机的应用领域正在迅速扩展。本文将为您详细介绍无人机的多种应用,帮助您更全面地了解这一现代…

SpringBoot整合FreeMarker生成word表格文件

SpringBoot整合FreeMarker生成word表格文件(使用FTL模板)_freemarker ftl模板-CSDN博客 Freemarker基本指令语法和集合指令语法SpringBoot整合FreeMarker生成word表格文件(使用FTL模板)_freemarker ftl模板-CSDN博客https://zhua…

基于单片机智能温室大棚监测系统

本设计以单片机为核心的智能温室大棚监测系统,用于监测大棚内的温湿度、土壤湿度、CO2浓度和光照强度。该系统以STM32F103C8T6芯片为核心控制单元,涵盖电源、按键、NB-IoT模块、显示屏模块、空气温湿度检测、土壤湿度检测、二氧化碳检测和光敏电阻等模块…

# JAVA中的Stream学习

JAVA中的Stream 1、Stream是什么? Stream 是 Java 8 引入的一个新的抽象层,用于处理数据集合。它可以让你以声明式的方式处理数据,类似于 SQL 语句的查询方式。 2、Stream能够做什么? 过滤:通过条件筛选数据。映射:转换数据…

Web前端之汉字排序、sort与localeCompare的介绍、编码顺序与字典顺序的区别

MENU 使用字典顺序对汉字进行排序(不支持多音字)编码顺序和字典顺序的区别sort与localeCompare的介绍 使用字典顺序对汉字进行排序(不支持多音字) 不使用拼音库,利用JavaScript的localeCompare方法直接按汉字的字典序排序。localeCompare可以在比较字符串时指定语言…