1. CNN
1.1. 设备参数
1.2. 代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import os
import cv2
import numpy as np
import time
import pandas as pd
import matplotlib.pyplot as plt# 自定义数据集类,用于加载医疗影像数据及其对应的标注坐标点数据
class MedicalImageDataset(Dataset):def __init__(self, image_dir, label_dir, transform=None):self.image_dir = image_dirself.label_dir = label_dir# 获取所有图像文件的路径列表,并按文件名排序self.image_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.bmp')])# 获取所有标注文件(.txt)的路径列表,并按文件名排序self.label_files = sorted([os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.endswith('.txt')])self.transform = transformdef __len__(self):# 返回数据集的大小,即图像的数量return len(self.image_files)def __getitem__(self, idx):image_path = self.image_files[idx]label_path = self.label_files[idx]# 使用OpenCV读取图像,默认读取格式为BGR,这里转换为RGB格式image = cv2.imread(image_path)image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)with open(label_path, 'r') as f:lines = f.readlines()# 以下是修改后的坐标点数据解析代码,适配逗号分割的x、y坐标格式points = []for line in lines[:19]:parts = line.strip().split(',')x = float(parts[0].strip())y = float(parts[1].strip())points.append([x, y])points = np.array(points)if self.transform:image = self.transform(image)# 返回图像数据(可进行了变换)和对应的坐标点数据(转换为torch的Tensor类型)return image, torch.from_numpy(points).float()# 定义一个简单的卷积神经网络模型结构
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1) # 第一个卷积层,输入通道3(RGB图像),输出通道16self.relu1 = nn.ReLU() # ReLU激活函数,增加非线性self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 最大池化层,进行下采样self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) # 第二个卷积层,输入16通道,输出32通道self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层,将经过卷积池化后的特征图展平后连接,输出维度为128self.fc1 = nn.Linear(32 * (2400 // 4) * (1935 // 4), 128)self.relu3 = nn.ReLU()# 最后一个全连接层,输出维度为19个坐标点 * 2(x、y坐标),用于预测关键点坐标self.fc2 = nn.Linear(128, 19 * 2)def forward(self, x):x = self.pool1(self.relu1(self.conv1(x)))x = self.pool2(self.relu2(self.conv2(x)))x = x.view(x.size(0), -1) # 将特征图展平x = self.relu3(self.fc1(x))x = self.fc2(x)return x# 定义图像数据所在文件夹路径和标注数据所在文件夹路径(这里原代码有个小错误,两个路径都写成了Image,应修改)
image_dir = 'dataset/Image'
label_dir = 'dataset/Label'
total_dataset = MedicalImageDataset(image_dir, label_dir, transform=transforms.ToTensor())
# 定义训练集、验证集、测试集的大小
train_size = 280
valid_size = 80
test_size = 40# 使用torch的随机划分函数,将总数据集划分为训练集、验证集和测试集
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(total_dataset, [train_size, valid_size, test_size])# 创建训练集的数据加载器,设置批次大小为2,并打乱数据顺序
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
# 创建验证集的数据加载器,批次大小为2,不打乱顺序
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False)
# 创建测试集的数据加载器,批次大小为2,不打乱顺序
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)# 定义训练函数,用于执行模型的训练过程,并显示训练进度条及当前损失值
def train(model, train_loader, criterion, optimizer, device):model.train()train_loss = 0start_time = time.time() # 记录训练开始时间# 使用tqdm创建进度条,用于直观展示训练进度,同时显示当前损失情况progress_bar = tqdm(train_loader, desc="Training")for images, targets in progress_bar:images, targets = images.to(device), targets.to(device) # 将数据移到指定设备(GPU或CPU)上optimizer.zero_grad() # 梯度清零outputs = model(images) # 前向传播,获取模型输出# 计算损失,将模型输出和目标坐标点数据形状进行调整后传入损失函数(这里用均方误差损失)loss = criterion(outputs, targets.view(-1, 38))loss.backward() # 反向传播,计算梯度optimizer.step() # 根据梯度更新模型参数train_loss += loss.item() * images.size(0)progress_bar.set_postfix({"Loss": train_loss / len(train_loader.dataset)})end_time = time.time() # 记录训练结束时间print(f"Training epoch took {end_time - start_time:.2f} seconds") # 打印本轮训练耗时return train_loss / len(train_loader.dataset)# 定义验证函数,用于在验证集上评估模型性能,计算验证损失
def validate(model, valid_loader, criterion, device):model.eval()valid_loss = 0start_time = time.time() # 记录验证开始时间with torch.no_grad(): # 在验证阶段不需要计算梯度,节省内存并加速计算for images, targets in valid_loader:images, targets = images.to(device), targets.to(device)outputs = model(images)loss = criterion(outputs, targets.view(-1, 38))valid_loss += loss.item() * images.size(0)end_time = time.time() # 记录验证结束时间print(f"Validation took {end_time - start_time:.2f} seconds") # 打印验证耗时return valid_loss / len(valid_loader.dataset)# 定义测试函数,用于在测试集上测试训练好的模型,并返回相关评价指标(这里以均方误差MSE和平均绝对误差MAE为例)
def test(model, test_loader, device):model.eval()criterion_mse = nn.MSELoss()mse_loss = 0mae_loss = 0num_samples = 0with torch.no_grad():for images, targets in test_loader:images, targets = images.to(device), targets.to(device)outputs = model(images)# 计算均方误差mse_loss += criterion_mse(outputs, targets.view(-1, 38)).item() * images.size(0)# 计算平均绝对误差diff = torch.abs(outputs - targets.view(-1, 38))mae_loss += diff.sum().item()num_samples += targets.size(0)# 计算平均均方误差mse = mse_loss / num_samples# 计算平均绝对误差mae = mae_loss / (num_samples * 38) # 因为输出是19个坐标点 * 2(x、y坐标),共38个值return mse, maeif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 判断是否可用GPU,选择相应设备model = SimpleCNN().to(device) # 将模型实例移到选定的设备上criterion = nn.MSELoss() # 定义损失函数为均方误差损失函数optimizer = optim.Adam(model.parameters(), lr=0.001) # 定义优化器为Adam,设置学习率为0.001num_epochs = 50 # 定义训练的总轮数train_losses = [] # 用于记录每一轮训练的损失值valid_losses = [] # 用于记录每一轮验证的损失值for epoch in range(num_epochs):print(f"Starting epoch {epoch + 1}/{num_epochs}") # 打印当前开始的轮次信息train_loss = train(model, train_loader, criterion, optimizer, device) # 执行一轮训练并获取训练损失valid_loss = validate(model, valid_loader, criterion, device) # 在验证集上评估获取验证损失train_losses.append(train_loss)valid_losses.append(valid_loss)print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}')# 保存训练好的模型参数到文件torch.save(model.state_dict(), 'keypoint_detection_model.pth')# 将训练损失和验证损失保存到.xlsx文件data = {'Epoch': list(range(1, num_epochs + 1)),'Train Loss': train_losses,'Valid Loss': valid_losses}df = pd.DataFrame(data)df.to_excel('loss_data.xlsx', sheet_name='Losses', index=False)# 绘制训练过程的图像(训练损失和验证损失曲线)并保存为.jpg图片plt.plot(train_losses, label='Train Loss')plt.plot(valid_losses, label='Valid Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training and Validation Loss')plt.legend()plt.savefig('loss_curve.jpg')# 加载训练好的模型参数进行测试model.load_state_dict(torch.load('keypoint_detection_model.pth'))test_mse, test_mae = test(model, test_loader, device)print(f"Test MSE: {test_mse:.4f}")print(f"Test MAE: {test_mae:.4f}")
1.3. 训练过程
root@autodl-container-bdf4448313-9cd0857f:~/KeyPoint# python CNN.py
Starting epoch 1/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.40it/s, Loss=7.48e+4]
Training epoch took 31.83 seconds
Validation took 4.22 seconds
Epoch 1/50, Train Loss: 74786.7275, Valid Loss: 10508.4497
Starting epoch 2/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.50it/s, Loss=6.55e+3]
Training epoch took 31.12 seconds
Validation took 4.04 seconds
Epoch 2/50, Train Loss: 6553.7229, Valid Loss: 5802.1468
Starting epoch 3/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.47it/s, Loss=4.52e+3]
Training epoch took 31.33 seconds
Validation took 4.23 seconds
Epoch 3/50, Train Loss: 4518.1617, Valid Loss: 2897.5908
Starting epoch 4/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.42it/s, Loss=3.98e+3]
Training epoch took 31.68 seconds
Validation took 4.36 seconds
Epoch 4/50, Train Loss: 3982.8769, Valid Loss: 3127.8260
Starting epoch 5/50
Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.44it/s, Loss=2.8e+3]
Training epoch took 31.54 seconds
Validation took 3.97 seconds
Epoch 5/50, Train Loss: 2801.6222, Valid Loss: 3637.8432
Starting epoch 6/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.48it/s, Loss=3.07e+3]
Training epoch took 31.23 seconds
Validation took 3.95 seconds
Epoch 6/50, Train Loss: 3067.7505, Valid Loss: 2115.4941
Starting epoch 7/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.50it/s, Loss=1.93e+3]
Training epoch took 31.09 seconds
Validation took 4.02 seconds
Epoch 7/50, Train Loss: 1934.6738, Valid Loss: 1724.4728
Starting epoch 8/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:30<00:00, 4.52it/s, Loss=1.76e+3]
Training epoch took 31.00 seconds
Validation took 4.00 seconds
Epoch 8/50, Train Loss: 1762.9244, Valid Loss: 2058.1415
Starting epoch 9/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.47it/s, Loss=2.27e+3]
Training epoch took 31.36 seconds
Validation took 4.29 seconds
Epoch 9/50, Train Loss: 2265.7019, Valid Loss: 2240.8221
Starting epoch 10/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.44it/s, Loss=1.85e+3]
Training epoch took 31.56 seconds
Validation took 4.01 seconds
Epoch 10/50, Train Loss: 1849.1750, Valid Loss: 2154.5502
Starting epoch 11/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.47it/s, Loss=2.01e+3]
Training epoch took 31.30 seconds
Validation took 4.10 seconds
Epoch 11/50, Train Loss: 2006.1661, Valid Loss: 2358.7412
Starting epoch 12/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.40it/s, Loss=1.59e+3]
Training epoch took 31.82 seconds
Validation took 4.32 seconds
Epoch 12/50, Train Loss: 1593.5506, Valid Loss: 2020.0459
Starting epoch 13/50
Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.42it/s, Loss=2e+3]
Training epoch took 31.70 seconds
Validation took 4.14 seconds
Epoch 13/50, Train Loss: 2000.1736, Valid Loss: 2205.5771
Starting epoch 14/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.48it/s, Loss=1.46e+3]
Training epoch took 31.24 seconds
Validation took 4.07 seconds
Epoch 14/50, Train Loss: 1462.9312, Valid Loss: 1503.5471
Starting epoch 15/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.42it/s, Loss=1.83e+3]
Training epoch took 31.66 seconds
Validation took 4.24 seconds
Epoch 15/50, Train Loss: 1830.8163, Valid Loss: 1543.1940
Starting epoch 16/50
Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.45it/s, Loss=1.6e+3]
Training epoch took 31.50 seconds
Validation took 4.09 seconds
Epoch 16/50, Train Loss: 1603.0216, Valid Loss: 1513.9206
Starting epoch 17/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.42it/s, Loss=1.06e+3]
Training epoch took 31.69 seconds
Validation took 4.18 seconds
Epoch 17/50, Train Loss: 1060.6837, Valid Loss: 1845.7085
Starting epoch 18/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.47it/s, Loss=1.03e+3]
Training epoch took 31.32 seconds
Validation took 4.07 seconds
Epoch 18/50, Train Loss: 1026.3920, Valid Loss: 1378.9620
Starting epoch 19/50
Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.47it/s, Loss=1.4e+3]
Training epoch took 31.34 seconds
Validation took 3.94 seconds
Epoch 19/50, Train Loss: 1396.0000, Valid Loss: 1437.9113
Starting epoch 20/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.46it/s, Loss=1.22e+3]
Training epoch took 31.42 seconds
Validation took 4.04 seconds
Epoch 20/50, Train Loss: 1217.4062, Valid Loss: 1902.9985
Starting epoch 21/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.45it/s, Loss=1.22e+3]
Training epoch took 31.50 seconds
Validation took 4.16 seconds
Epoch 21/50, Train Loss: 1224.0013, Valid Loss: 1347.9988
Starting epoch 22/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.46it/s, Loss=2.14e+3]
Training epoch took 31.43 seconds
Validation took 4.21 seconds
Epoch 22/50, Train Loss: 2135.4654, Valid Loss: 1839.5696
Starting epoch 23/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.49it/s, Loss=1.45e+3]
Training epoch took 31.17 seconds
Validation took 3.87 seconds
Epoch 23/50, Train Loss: 1446.5424, Valid Loss: 1405.1328
Starting epoch 24/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:32<00:00, 4.36it/s, Loss=1.05e+3]
Training epoch took 32.14 seconds
Validation took 4.03 seconds
Epoch 24/50, Train Loss: 1049.4809, Valid Loss: 1279.9357
Starting epoch 25/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.48it/s, Loss=1.16e+3]
Training epoch took 31.27 seconds
Validation took 4.11 seconds
Epoch 25/50, Train Loss: 1156.5278, Valid Loss: 2135.3058
Starting epoch 26/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.46it/s, Loss=1.03e+3]
Training epoch took 31.38 seconds
Validation took 3.96 seconds
Epoch 26/50, Train Loss: 1025.4073, Valid Loss: 1808.6484
Starting epoch 27/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.46it/s, Loss=1.29e+3]
Training epoch took 31.42 seconds
Validation took 4.17 seconds
Epoch 27/50, Train Loss: 1291.6064, Valid Loss: 1668.1155
Starting epoch 28/50
Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.45it/s, Loss=1.3e+3]
Training epoch took 31.46 seconds
Validation took 4.27 seconds
Epoch 28/50, Train Loss: 1301.2716, Valid Loss: 1434.8759
Starting epoch 29/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.41it/s, Loss=704]
Training epoch took 31.74 seconds
Validation took 4.33 seconds
Epoch 29/50, Train Loss: 703.5830, Valid Loss: 1480.0195
Starting epoch 30/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:32<00:00, 4.37it/s, Loss=777]
Training epoch took 32.05 seconds
Validation took 4.34 seconds
Epoch 30/50, Train Loss: 777.2662, Valid Loss: 1417.0025
Starting epoch 31/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.38it/s, Loss=769]
Training epoch took 31.96 seconds
Validation took 4.29 seconds
Epoch 31/50, Train Loss: 768.7920, Valid Loss: 1571.4930
Starting epoch 32/50
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:31<00:00, 4.45it/s, Loss=845]
Training epoch took 31.49 seconds
Validation to