深度学习技巧应用22-构建万能数据生成类的技巧,适用于CNN,RNN,GNN模型的调试与训练贯通

news/2025/1/20 2:53:16/

大家好,我是微学AI,今天给大家介绍一下深度学习技巧应用22-构建万能数据生成类的技巧,适用于CNN,RNN,GNN模型的调试与训练贯通。本文将实现了一个万能数据生成类的编写,并使用PyTorch框架训练CNN、RNN和GNN模型。

目录:
1.背景介绍
2.依赖库介绍
3.万能的数据生成器介绍
4.CNN,RNN,GNN模型搭建
5.数据生成与模型训练
6.训练结果与总结
在这里插入图片描述

1.背景介绍

在人工智能模型训练过程中,我们需要进行一些实验、测试或调试,我们可能需要一个具有特定形状和数量的数据集来验证我们的算法或模型。通过构建一个万能的数据生成器,我们可以灵活地生成各种形状和大小的数据集,无需手动制作和准备真实数据集。

其次,数据生成器还可以用于探索数据集的性质和特征。通过生成具有特定分布、特征或规律的数据集,我们可以更深入地了解数据之间的关系、特征之间的相互影响以及数据的结构等。这对于数据预处理、特征工程和模型选择都非常有帮助。

此外,数据生成器还可以用于实现数据增强技术。数据增强是指通过对原始数据进行一系列变换或扰动来生成新的训练样本,以增加训练数据的多样性和泛化能力。通过构建一个万能的数据生成器,我们可以定义各种数据增强方法,并在训练过程中动态地生成增强后的样本,从而提高模型的稳健性和可靠性。

2.依赖库介绍

首先,我们需要引入以下依赖库:

  • torch:PyTorch框架
  • torch.optim:torch.optim是PyTorch框架中的一个模块,用于优化模型的参数。它提供了各种优化算法,如随机梯度下降(SGD)、Adam、Adagrad等。通过选择适当的优化算法和调整参数,可以使模型在训练过程中更好地收敛并获得更好的性能。
  • torch.utils.data:torch.utils.data是PyTorch框架中的一个模块,用于处理数据集的工具类。它提供了一些常用的数据处理操作,如数据加载、批量处理、数据迭代和数据转换等。通过使用torch.utils.data,可以方便地将数据集加载到模型中进行训练,并且能够灵活地处理不同格式的数据。
  • numpy:numpy是一个Python库,主要用于进行数值计算和科学计算。它提供了多维数组对象(ndarray)和一系列用于操作数组的函数。numpy可以高效地进行数值运算,并且支持广播(broadcasting)和向量化操作,因此在科学计算、数据分析和机器学习等领域都得到广泛应用。在PyTorch中,numpy可以与torch.Tensor进行无缝的转换,方便进行数据的处理和转换。

3.万能的数据生成器介绍

首先我们需要定义了一个名为UniversalDataset的数据集类,用于生成具有特定形状和数量的数据和标签。

在类的初始化方法__init__中,我们传入了三个参数:data_shape表示数据的形状(一个元组),target_shape表示标签的形状(一个元组),num_samples表示数据集中样本的数量。通过这三个参数生成数据。

接着,我们实现了__len__方法,该方法返回数据集中样本的数量,即num_samples。

再定义__getitem__方法,该方法根据索引idx返回数据集中索引对应的数据和标签。在这个方法中,我们首先创建了一个与data_shape相同形状的全零张量data,以及一个与target_shape相同形状的全零张量target。

然后,我们分别计算了数据和标签的维度,即data_dims和target_dims。

本文使用torch.linspace函数在0和1之间生成长度为data_dim_size的等间隔数据范围data_range,并通过reshape方法将其重新塑形为data_shape_expanded形状的张量。然后,我们将这个塑形后的数据范围加到数据张量data上。

我们对标签也进行了类似的操作,生成了一个有规律的标签张量target。

最后,我们返回了数据张量data和标签张量target作为这个索引对应的样本。

通过这个类,我们可以根据需要生成具有指定形状和数量的数据集,并且数据和标签都是有规律的,方便进行后续的训练和评估。

import torch
from torch import nn
from torch.utils.data import DataLoader, Datasetclass UniversalDataset(Dataset):def __init__(self, data_shape, target_shape, num_samples):self.data_shape = data_shapeself.target_shape = target_shapeself.num_samples = num_samplesdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 生成数据和标签data = torch.zeros(self.data_shape)target = torch.zeros(self.target_shape)# 计算数据和标签的维度data_dims = len(self.data_shape)target_dims = len(self.target_shape)# 生成有规律的数据和标签for dim in range(data_dims):data_dim_size = self.data_shape[dim]data_range = torch.linspace(0, 1, data_dim_size)data_shape_expanded = [1] * data_dimsdata_shape_expanded[dim] = data_dim_sizedata += data_range.reshape(data_shape_expanded)for dim in range(target_dims):target_dim_size = self.target_shape[dim]target_range = torch.linspace(0, 1, target_dim_size)target_shape_expanded = [1] * target_dimstarget_shape_expanded[dim] = target_dim_sizetarget += target_range.reshape(target_shape_expanded)return data, target

4.CNN,RNN,GNN模型搭建

class CNNModel(nn.Module):def __init__(self, input_shape):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(input_shape[0], 16, kernel_size=3, stride=1, padding=1)self.fc = nn.Linear(16 * (input_shape[1] // 2) * (input_shape[2] // 2), 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = self.fc(x)return xclass RNNModel(nn.Module):def __init__(self, input_shape):super(RNNModel, self).__init__()self.rnn = nn.RNN(input_shape[1], 64, batch_first=True)self.fc = nn.Linear(64, 10)def forward(self, x):_, h_n = self.rnn(x)x = self.fc(h_n.squeeze(0))return xclass GNNModel(nn.Module):def __init__(self, input_shape):super(GNNModel, self).__init__()self.fc1 = nn.Linear(input_shape[1], 32)self.fc2 = nn.Linear(32, 10)def forward(self, x):x = torch.mean(x, dim=1)x = self.fc1(x)x = nn.functional.relu(x)x = self.fc2(x)return x

5.数据生成与模型训练

# 定义训练函数
def train(model, dataloader, criterion, optimizer):running_loss = 0.0correct = 0total = 0for inputs, labels in dataloader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels.argmax(dim=1)).sum().item()loss.backward()optimizer.step()running_loss += loss.item()epoch_loss = running_loss / len(dataloader)epoch_acc = correct / totalreturn epoch_loss, epoch_acc# 设置参数
data_shape_cnn = (3, 32, 32)  # (channels, height, width)
target_shape = (10,)
num_samples = 1000
batch_size = 32
learning_rate = 0.001
num_epochs = 10# 创建数据集和数据加载器
dataset = UniversalDataset(data_shape_cnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 创建CNN模型、优化器和损失函数
cnn_model = CNNModel(data_shape_cnn)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=learning_rate)
cnn_criterion = nn.CrossEntropyLoss()print('CNN模型训练:')
# 训练CNN模型
for epoch in range(num_epochs):cnn_loss, cnn_acc = train(cnn_model, dataloader, cnn_criterion, cnn_optimizer)print(f'CNN - Epoch {epoch+1}/{num_epochs}, Loss: {cnn_loss:.4f}, Accuracy: {cnn_acc:.4f}')# 重新创建数据集和数据加载器
data_shape_rnn = (20, 32)  # (sequence_length, input_size, hidden_size)
dataset = UniversalDataset(data_shape_rnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 创建RNN模型、优化器和损失函数
rnn_model = RNNModel(data_shape_rnn)
rnn_optimizer = torch.optim.Adam(rnn_model.parameters(), lr=learning_rate)
rnn_criterion = nn.CrossEntropyLoss()print('RNN模型训练:')
# 训练RNN模型
for epoch in range(num_epochs):rnn_loss, rnn_acc = train(rnn_model, dataloader, rnn_criterion, rnn_optimizer)print(f'RNN - Epoch {epoch+1}/{num_epochs}, Loss: {rnn_loss:.4f}, Accuracy: {rnn_acc:.4f}')# 重新创建数据集和数据加载器
data_shape_gnn = (10, 100)  # (num_nodes, node_features)
dataset = UniversalDataset(data_shape_gnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 创建GNN模型、优化器和损失函数
gnn_model = GNNModel(data_shape_gnn)
gnn_optimizer = torch.optim.Adam(gnn_model.parameters(), lr=learning_rate)
gnn_criterion = nn.CrossEntropyLoss()print('GNN模型训练:')
# 训练GNN模型
for epoch in range(num_epochs):gnn_loss, gnn_acc = train(gnn_model, dataloader, gnn_criterion, gnn_optimizer)print(f'GNN - Epoch {epoch+1}/{num_epochs}, Loss: {gnn_loss:.4f}, Accuracy: {gnn_acc:.4f}')

6.训练结果与总结

运行结果:

CNN模型训练:
CNN - Epoch 1/10, Loss: 10.4031, Accuracy: 0.5840
CNN - Epoch 2/10, Loss: 10.2561, Accuracy: 1.0000
CNN - Epoch 3/10, Loss: 10.2503, Accuracy: 1.0000
CNN - Epoch 4/10, Loss: 10.2496, Accuracy: 1.0000
CNN - Epoch 5/10, Loss: 10.2495, Accuracy: 1.0000
CNN - Epoch 6/10, Loss: 10.2494, Accuracy: 1.0000
CNN - Epoch 7/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 8/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 9/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 10/10, Loss: 10.2493, Accuracy: 1.0000
RNN模型训练:
RNN - Epoch 1/10, Loss: 10.3851, Accuracy: 0.9680
RNN - Epoch 2/10, Loss: 10.2606, Accuracy: 1.0000
RNN - Epoch 3/10, Loss: 10.2551, Accuracy: 1.0000
RNN - Epoch 4/10, Loss: 10.2531, Accuracy: 1.0000
RNN - Epoch 5/10, Loss: 10.2520, Accuracy: 1.0000
RNN - Epoch 6/10, Loss: 10.2513, Accuracy: 1.0000
RNN - Epoch 7/10, Loss: 10.2509, Accuracy: 1.0000
RNN - Epoch 8/10, Loss: 10.2506, Accuracy: 1.0000
RNN - Epoch 9/10, Loss: 10.2504, Accuracy: 1.0000
RNN - Epoch 10/10, Loss: 10.2502, Accuracy: 1.0000
GNN模型训练:
GNN - Epoch 1/10, Loss: 10.9591, Accuracy: 0.0400
GNN - Epoch 2/10, Loss: 10.3914, Accuracy: 1.0000
GNN - Epoch 3/10, Loss: 10.2818, Accuracy: 1.0000
GNN - Epoch 4/10, Loss: 10.2635, Accuracy: 1.0000
GNN - Epoch 5/10, Loss: 10.2569, Accuracy: 1.0000
GNN - Epoch 6/10, Loss: 10.2539, Accuracy: 1.0000
GNN - Epoch 7/10, Loss: 10.2524, Accuracy: 1.0000
GNN - Epoch 8/10, Loss: 10.2515, Accuracy: 1.0000
GNN - Epoch 9/10, Loss: 10.2509, Accuracy: 1.0000
GNN - Epoch 10/10, Loss: 10.2505, Accuracy: 1.0000

本文主要介绍了如何创建一个万能的数据生成类,可以根据输入的形状参数生成不同形状的数据。然后,将生成的数据和标签输入到CNN、RNN和GNN模型中进行训练,并打印出损失值和准确率。后续我们可以根据实际应用中可能需要根据具体任务做更多修改和扩展。


http://www.ppmy.cn/news/711261.html

相关文章

论开学第三个月干了点啥

在开学第二个月的最后一天也是光棍节,先把这个坑开了. 要开始给自己剪枝了. 11.12 早上艰难起床,买了点面包就去上线代了.线代体验不是很好....然后程设讲指针也没太听.不喜欢malloc的语法.没有new和delete好看.在程设的时候尝试鼓捣大作业结果好像不是很好鼓捣.只是换了换背景…

由争议拼多多之货找人想到的 BlockChain Storage 之5、区块链存储 - 存储供需的智能匹配...

今天这篇文章会涉及到比特币披萨节,拼多多,央行姚前,算法经济,拆借,保险基金 …… 在系列之四《SDS之BlockChain Storage系列:4、为什说区块链存储是下一个热点(下)& 互联网与物…

微信APP分析报告

每一个好的产品,之所以被大家所认可,肯定是有它的原因的,也许您会说“时势造英雄”,但我想说,如果产品本身不给力,也可以被“后来者居上”。现在是互联网时代,产品本身不存在秘密而言,而产品的商业模式是企业生存发展之关键所在,其每个优秀产品的商业模式是很难被复制…

大数据分析技术方案

转自 lWX471878的博客 http://xinsheng.huawei.com/cn/blog/detail_80005.html 一. 目标 现在已经进入大数据时代, 数据是无缝连接网络世界与物理世界的DNA。发现数据DNA、重组数据DNA是人类不断认识、探索、实践大数据的持续过程。…

牛逼的人很早就开始牛逼了

作者:拾遗君 来源:拾遗(ID:shiyi201633) “牛逼的人很早就开始牛逼了。”这话很有道理。虽然有“大器晚成”一说,但罗马绝对不是一天建成的。能够“大器晚成”的人,往往是一个一直很牛逼的人。 …

20194709 - 第一次作业-博客初体验

第一章术语 源程序软件服务软件架构源代码管理配置管理软件测试需求分析软件维护服务运营用户体验编译源代码错误代号程序设计语言996BUG P18-2 在第四章 4.5.2 结对编程,书中写到「一对程序员肩并肩,平等地,互补的进行开发工作」有关于独立开…

Linus Torvalds 传记

转载自:http://www.chenjunlu.com/2014/07/linus-torvalds-biography/ 在很久很久以前,我写过一个 Linus 的系列,当时收获了一箩筐的好评,我一边清点战果,一边手抚刀锋信誓旦旦的要把这个系列一口气写完,然…

O’Reilly精品图书系列:编写可读代码的艺术].(鲍斯维尔等).尹哲等

O’Reilly精品图书系列:编写可读代码的艺术].(鲍斯维尔等).尹哲等 A first Course in Logic An Introduction To Model Theory Proof Theory Computability And Complexity - Shawn Hedman.djvu: http://www.t00y.com/file/59442904Advanced Complexity Theory Lctn…