深度学习-医学影像诊断

devtools/2025/2/14 4:52:06/

以下以使用深度学习进行医学影像(如 X 光片)的肺炎诊断为例,为你展示基于 PyTorch 框架的代码实现。我们将构建一个简单的卷积神经网络(CNN)模型,使用公开的肺炎 X 光影像数据集进行训练和评估。

1. 安装必要的库

pip install torch torchvision numpy matplotlib pandas

2. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集
train_dataset = datasets.ImageFolder(root='path/to/train_data', transform=transform)
test_dataset = datasets.ImageFolder(root='path/to/test_data', transform=transform)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 定义简单的 CNN 模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(32 * 56 * 56, 128)self.relu3 = nn.ReLU()self.fc2 = nn.Linear(128, 2)def forward(self, x):x = self.pool1(self.relu1(self.conv1(x)))x = self.pool2(self.relu2(self.conv2(x)))x = x.view(-1, 32 * 56 * 56)x = self.relu3(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
train_losses = []
for epoch in range(num_epochs):running_loss = 0.0for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()epoch_loss = running_loss / len(train_loader)train_losses.append(epoch_loss)print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}')# 绘制训练损失曲线
plt.plot(train_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()# 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

3. 代码解释

  • 数据预处理

    • 使用 transforms.Compose 定义了一系列的数据预处理操作,包括调整图像大小、转换为张量和归一化。
    • transforms.Resize((224, 224)) 将图像调整为 224x224 大小。
    • transforms.ToTensor() 将图像转换为张量。
    • transforms.Normalize 对图像进行归一化处理。
  • 数据集加载

    • 使用 datasets.ImageFolder 加载训练集和测试集,需要将 path/to/train_datapath/to/test_data 替换为实际的数据集路径。
    • DataLoader 用于创建数据加载器,方便批量加载数据。
  • 模型定义

    • SimpleCNN 类定义了一个简单的卷积神经网络模型,包含两个卷积层、两个池化层和两个全连接层。
  • 训练过程

    • 使用 nn.CrossEntropyLoss 作为损失函数,optim.Adam 作为优化器。
    • 在每个 epoch 中,遍历训练数据,计算损失并进行反向传播和参数更新。
  • 模型评估

    • 将模型设置为评估模式(model.eval()),在测试集上进行预测,并计算准确率。

4. 注意事项

  • 数据集:你需要准备合适的医学影像数据集,并将其按照训练集和测试集进行划分,每个类别放在不同的文件夹中。
  • 模型复杂度:这里的 SimpleCNN 是一个简单的模型,在实际应用中,可能需要使用更复杂的预训练模型(如 ResNet、DenseNet 等)来提高诊断准确率。
  • 计算资源:训练深度学习模型需要一定的计算资源,建议在 GPU 上运行以提高训练速度。可以使用 torch.cuda.is_available() 检查是否有可用的 GPU,并将模型和数据移动到 GPU 上进行训练。例如:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
images, labels = images.to(device), labels.to(device)

如果你有其他具体需求,如使用不同的模型架构、处理不同类型的医学影像等,可以进一步调整代码。


http://www.ppmy.cn/devtools/158676.html

相关文章

Conda 虚拟环境与 venv、virtualenv、pipenv 的对比

1. 引言 在 Python 开发中,虚拟环境是解决不同项目依赖冲突的关键工具。Python 提供了多种虚拟环境管理工具,包括 Conda、venv、virtualenv 和 pipenv。每种工具都有其独特的特点和适用场景。本篇博客将简要对比这些工具,帮助你选择最适合的…

【闲谈集】学网络应用开发好还是学网络安全好?

互联网各领域资料分享专区(不定期更新): Sheet 前言 网络应用开发主要涉及创建网站、应用程序,前端后端这些技术栈,而网络安全则是保护系统、网络免受攻击,涉及渗透测试、漏洞分析等。 喜欢构建东西,可能更适合开发&…

分布式系统知识点总结

一、一致性协议 ¥1. CAP理论 CAP理论是分布式系统设计中的一套指导原则,它指出在网络分区的情况下,一个分布式系统最多只能同时满足以下三点中的两点: 一致性(Consistency):所有节点在同一时…

【3.Git与Github的历史和区别】

目录 Git的历史和Github的区别本质和功能 Git的历史和Github的区别 Git是由Linux内核的创造者Linus Torvalds于2005年创建的。当时,Linux内核开源项目使用BitKeeper作为版本控制系统,但2005年BitKeeper的商业公司终止了与Linux社区的合作,收…

希尔排序(C#)

目录 1 什么是希尔排序 2 算法步骤 3 代码实现 1 什么是希尔排序 希尔排序是插入排序的一种更高效的改进版本,也称为缩小增量排序。它的基本思想是将原始数据分成多个子序列来进行插入排序,通过逐渐缩小子序列的间隔(增量)&a…

碰一碰发视频源码技术开发,支持OEM

一、引言 在当今数字化信息快速传播的时代,碰一碰发视频这种便捷的数据交互方式正逐渐走进人们的生活。从技术实现角度来看,其后台开发逻辑是确保整个功能稳定运行的关键。本文将深入剖析碰一碰发视频后台开发的核心逻辑,为开发者提供技术参…

DeepSeek 助力 Vue 开发:打造丝滑的进度条

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 Deep…

34.Qt使用回调函数

新建Qt项目&#xff0c;添加回调函数所在的类Callback 项目文件如下所示 Callback.h代码 #ifndef CALLBACK_H #define CALLBACK_H#include <QObject>class Callback : public QObject {Q_OBJECT public:explicit Callback(QObject *parent nullptr);public:static voi…