3. train_encoder_decoder.py

ops/2024/10/21 3:26:54/

train_encoder_decoder.py

python">#__future__ 模块提供了一种方式,允许开发者在当前版本的 Python 中使用即将在将来版本中成为标准的功能和语法特性。此处为了确保代码同时兼容Python 2和Python 3版本中的print函数
from __future__ import print_function # 导入标准库和第三方库#os 是一个标准库模块,全称为 "operating system",用于提供与操作系统交互的功能。导入os.path模块用于处理文件和目录路径
import os.path 
#从os模块中导入了path子模块,可以直接使用path来调用os.path中的函数(上面的代码可以不用写)
from os import path #导入了sys模块,用于系统相关的参数和函数
import sys 
#导入了math模块,提供了数学运算函数
import math 
#导入了NumPy库,并使用np作为别名,NumPy是用于科学计算的基础库
import numpy as np 
#导入了Pandas库,并使用pd作为别名,Pandas是用于数据分析的强大库
import pandas as pd # 导入深度学习相关库
# keras 是一个机器学习和深度学习的库。backend 模块提供了对底层深度学习框架(如TensorFlow、Theano等)的访问接口,使得在不同的后端之间进行无缝切换变得更加容易。
import tensorflow as tf # 导入了Keras的backend模块,并使用K作为别名,用于访问后端引擎的函数
from keras import backend as K 
# Model 类在 Keras 中允许用户以函数式 API 的方式构建更为复杂的神经网络模型。通过使用 Model 类,可以自由地定义输入层、输出层和中间层,并将它们连接起来形成一个完整的模型。
from keras.models import Model # 1. LSTM (Long Short-Term Memory) 和 GRU (Gated Recurrent Unit) 都是循环神经网络 (RNN) 的变体,可以用来学习长期依赖关系,用于处理序列数据。
# 2. 在处理序列数据时,经常需要将某个层(如 Dense 层)应用于序列中的每一个时间步。TimeDistributed 可以将这样的层包装起来,使其能够处理整个序列。
# 3. 在函数式 API 中,可以使用 Input 来定义模型的输入节点,指定输入的形状和数据类型。
# 4. 在神经网络中,Dense 层是最基本的层之一,每个输入节点都与输出节点相连,用于学习数据中的非线性关系。
# 5. RepeatVector接受一个 2D 张量作为输入,并重复其内容 n 次生成一个3D张量,用于序列数据处理中的某些操作,例如将上下文向量重复多次以与每个时间步相关联。
from keras.layers import LSTM, GRU, TimeDistributed, Input, Dense, RepeatVector # 1. CSVLogger 是一个回调函数,用于将每个训练周期的性能指标(如损失和指标值)记录到 CSV 文件中。训练完成后,可以使用记录的数据进行分析和可视化,帮助了解模型在训练过程中的表现
# 2. EarlyStopping 是一个回调函数,用于在训练过程中根据验证集的表现来提前终止训练。它监控指定的性能指标(如验证损失)并在连续若干个周期内没有改善时停止训练,防止模型过拟合。
# 3. TerminateOnNaN 是一个回调函数,用于在训练过程中检测到损失函数返回 NaN(Not a Number)时提前终止训练。这可以帮助捕捉和处理训练过程中出现的数值问题,避免模型继续训练无效参数
from keras.callbacks import CSVLogger, EarlyStopping, TerminateOnNaN# regularizers 用于定义正则化项,减少模型的过拟合,通过向模型的损失函数添加惩罚项来限制模型参数的大小或者复杂度。
from keras import regularizers # Adam (Adaptive Moment Estimation) 优化器是基于随机梯度下降 (Stochastic Gradient Descent, SGD) 的方法之一,但它结合了动量优化和自适应学习率的特性: 
# 1. 动量(Momentum):类似于经典的随机梯度下降中的动量项,Adam会在更新参数时考虑上一步梯度的指数加权平均值,以减少梯度更新的方差,从而加速收敛; 
# 2. 自适应学习率:Adam根据每个参数的梯度的一阶矩估计(均值)和二阶矩估计(方差)来自动调整学习率。这种自适应学习率的机制可以使得不同参数有不同的学习率,从而更有效地优化模型。
from keras.optimizers import Adam # 1. 假设有一个函数 func(a, b, c),通过 partial(func, 1) 可以创建一个新函数,相当于 func(1, b, c),其中 1 是已经固定的参数。
# 2. update_wrapper 是一个函数,用于更新后一个函数的元信息(比如文档字符串、函数名等)到前一个函数上
from functools import partial, update_wrapper 
def wrapped_partial(func, *args, **kwargs):partial_func = partial(func, *args, **kwargs)update_wrapper(partial_func, func)return partial_func# 这是一个自定义的损失函数,计算加权的均方误差(Mean Squared Error)
# y_true是真实值,y_pred是预测值,weights是权重
# axis=-1指定了在计算均值时应该沿着最内层的轴进行操作,即在每个样本或数据点上进行平均,而不是在整个批次或特征维度上进行平均
def weighted_mse(y_true, y_pred, weights):return K.mean(K.square(y_true - y_pred) * weights, axis=-1)# 这部分代码用于选择使用的GPU设备。它从命令行参数中获取一个整数值gpu,如果gpu小于3,则设置CUDA环境变量以指定使用的GPU设备
import os
gpu = int(sys.argv[-13])
if gpu < 3:os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152os.environ["CUDA_VISIBLE_DEVICES"]= "{}".format(gpu)from tensorflow.python.client import device_libprint(device_lib.list_local_devices())# 这部分代码获取了一系列命令行参数,并将它们分别赋值给变量 
# 这些参数包括dataname数据集名称、nb_batches训练的批次数量、nb_epochs训练周期数、lr学习率、penalty正则化惩罚、dr丢弃率、patience耐心(用于Early Stopping),n_hidden神经网络中隐藏层的数量,hidden_activation隐藏层激活函数
imp = sys.argv[-1]
T = sys.argv[-2]
t0 = sys.argv[-3]
dataname = sys.argv[-4] 
nb_batches = sys.argv[-5]
nb_epochs = sys.argv[-6]
lr = float(sys.argv[-7])
penalty = float(sys.argv[-8])
dr = float(sys.argv[-9])
patience = sys.argv[-10]
n_hidden = int(sys.argv[-11])
hidden_activation = sys.argv[-12]# results_directory 是一个字符串,表示将要创建的结果文件夹路径,dataname 是之前从命令行参数中获取的数据集名称。
# .format(dataname) 是字符串的格式化方法,它会将 dataname 变量的值插入到占位符 {} 的位置。
# 如果这个文件夹路径不存在,就使用 os.makedirs 函数创建它。这个路径通常用于存储训练模型的结果或者日志。
results_directory = 'results/encoder-decoder/{}'.format(dataname)
if not os.path.exists(results_directory):os.makedirs(results_directory)# 定义了一个函数 create_model,用于创建、编译和返回一个循环神经网络(RNN)模型
def create_model(n_pre, n_post, nb_features, output_dim, lr, penalty, dr, n_hidden, hidden_activation):""" creates, compiles and returns a RNN model @param nb_features: the number of features in the model"""# 这里定义了两个输入层:# 1. inputs 是一个形状为 (n_pre, nb_features) 的输入张量,用于模型的主输入;# 2. weights_tensor 是一个形状相同的张量,用于传递权重或其他需要的信息inputs = Input(shape=(n_pre, nb_features), name="Inputs")  weights_tensor = Input(shape=(n_pre, nb_features), name="Weights") # 编码器,这里使用了两个 LSTM 层: # lstm_1 的主要作用是将输入序列转换为一个语义上丰富的固定长度表示(即隐藏状态),并且该表示包含了输入序列的全部信息。这个固定长度的表示将作为解码器的输入,用于生成目标序列。# 1. n_hidden:指定 LSTM 层的隐藏单元数,决定了网络的记忆容量和复杂度。# 2. dropout=dr 和 recurrent_dropout=dr:分别指定了输入和循环 dropout 的比例,有助于防止过拟合。# 3. activation=hidden_activation:设置了 LSTM 单元的激活函数,这里是通过 hidden_activation 参数传递的。# 4. return_sequences=True:指定返回完整的输出序列,而不是只返回最后一个时间步的输出。这是为了将完整的输入序列信息编码成隐藏状态序列,以便后续的解码器使用。# lstm_2 是一个相同的 LSTM 层,但它只返回最后一个时间步的输出 lstm_1 = LSTM(n_hidden, dropout=dr, recurrent_dropout=dr, activation=hidden_activation, return_sequences=True, name='LSTM_1')(inputs) lstm_2 = LSTM(n_hidden, activation=hidden_activation, return_sequences=False, name='LSTM_2')(lstm_1) repeat = RepeatVector(n_post, name='Repeat')(lstm_2) # get the last output of the LSTM and repeats itgru_1 = GRU(n_hidden, activation=hidden_activation, return_sequences=True, name='Decoder')(repeat)  # Decoderoutput= TimeDistributed(Dense(output_dim, activation='linear', kernel_regularizer=regularizers.l2(penalty), name='Dense'), name='Outputs')(gru_1)model = Model([inputs, weights_tensor], output)# model.compile(optimizer=Adam(lr=lr), loss=cl) 对模型进行编译。# optimizer=Adam(lr=lr) 指定了优化器为 Adam,并设置了学习率为 lr。# loss=cl 指定了损失函数为 cl,即上面定义的加权均方误差函数。cl = wrapped_partial(weighted_mse, weights=weights_tensor)model.compile(optimizer=Adam(lr=lr), loss=cl)print(model.summary()) return modeldef train_model(model, dataX, dataY, weights, nb_epoches, nb_batches):# Prepare model checkpoints and callbacksstopping = EarlyStopping(monitor='val_loss', patience=int(patience), min_delta=0, verbose=1, mode='min', restore_best_weights=True)csv_logger = CSVLogger('results/encoder-decoder/{}/training_log_{}_{}_{}_{}_{}_{}_{}_{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), separator=',', append=False)terminate = TerminateOnNaN()# 训练过程中会生成一个 history 对象,其中包含了训练过程中的损失和指标等信息,但并没有直接输出最终的参数值history = model.fit(x=[dataX,weights], y=dataY, batch_size=nb_batches, verbose=1,epochs=nb_epoches, callbacks=[stopping,csv_logger,terminate],validation_split=0.2)def test_model():n_post = int(1)n_pre =int(t0)-1seq_len = int(T)wx = np.array(pd.read_csv("data/{}-wx-{}.csv".format(dataname,imp)))print('raw wx shape', wx.shape)  wXC = []for i in range(seq_len-n_pre-n_post):wXC.append(wx[i:i+n_pre]) wXC = np.array(wXC)print('wXC shape:', wXC.shape)x = np.array(pd.read_csv("data/{}-x-{}.csv".format(dataname,imp)))print('raw x shape', x.shape) dXC, dYC = [], []for i in range(seq_len-n_pre-n_post):dXC.append(x[i:i+n_pre])dYC.append(x[i+n_pre:i+n_pre+n_post])dataXC = np.array(dXC)dataYC = np.array(dYC)print('dataXC shape:', dataXC.shape)print('dataYC shape:', dataYC.shape)nb_features = dataXC.shape[2]output_dim = dataYC.shape[2]# create and fit the encoder-decoder networkprint('creating model...')model = create_model(n_pre, n_post, nb_features, output_dim, lr, penalty, dr, n_hidden, hidden_activation)train_model(model, dataXC, dataYC, wXC, int(nb_epochs), int(nb_batches))# now testprint('Generate predictions on full training set')preds_train = model.predict([dataXC,wXC], batch_size=int(nb_batches), verbose=1)print('predictions shape =', preds_train.shape)preds_train = np.squeeze(preds_train)print('predictions shape (squeezed)=', preds_train.shape)print('Saving to results/encoder-decoder/{}/encoder-decoder-{}-train-{}-{}-{}-{}-{}-{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches))np.savetxt("results/encoder-decoder/{}/encoder-decoder-{}-train-{}-{}-{}-{}-{}-{}.csv".format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), preds_train, delimiter=",")print('Generate predictions on test set')wy = np.array(pd.read_csv("data/{}-wy-{}.csv".format(dataname,imp)))print('raw wy shape', wy.shape)  wY = []for i in range(seq_len-n_pre-n_post):wY.append(wy[i:i+n_pre]) # weights for outputswXT = np.array(wY)print('wXT shape:', wXT.shape)y = np.array(pd.read_csv("data/{}-y-{}.csv".format(dataname,imp)))print('raw y shape', y.shape)  dXT = []for i in range(seq_len-n_pre-n_post):dXT.append(y[i:i+n_pre]) # treated is inputdataXT = np.array(dXT)print('dataXT shape:', dataXT.shape)preds_test = model.predict([dataXT, wXT], batch_size=int(nb_batches), verbose=1)print('predictions shape =', preds_test.shape)preds_test = np.squeeze(preds_test)print('predictions shape (squeezed)=', preds_test.shape)print('Saving to results/encoder-decoder/{}/encoder-decoder-{}-test-{}-{}-{}-{}-{}-{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches))np.savetxt("results/encoder-decoder/{}/encoder-decoder-{}-test-{}-{}-{}-{}-{}-{}.csv".format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), preds_test, delimiter=",")def main():test_model()return 1if __name__ == "__main__":main()

http://www.ppmy.cn/ops/55836.html

相关文章

Java并发编程-AQS详解及案例实战(下篇)

文章目录 读写锁互斥:基于AQS的state二进制高低16位完成互斥判断`state`变量的位分配读写锁互斥判断代码实现总结释放写锁的源码剖析以及对AQS队列唤醒阻塞线程的过程释放写锁的源码AQS的`release`方法唤醒等待线程总结基于CAS实现多线程并发同时只有一个可以加读锁使用CAS实现…

elementui中@click短时间内多次触发,@click重复点击,做不允许重复点击处理

click快速点击&#xff0c;发生多次触发 2.代码示例&#xff1a; //html<el-button :loading"submitLoading" type"primary" click"submitForm">确 定</el-button>data() {return {submitLoading:false,}}//方法/** 提交按钮 */sub…

2^k进制数(对每部分代码详解)

2^k进制数 题目描述 设 r r r 是个 2 k 2^k 2k 进制数&#xff0c;并满足以下条件&#xff1a; r r r 至少是个 2 2 2 位的 2 k 2^k 2k 进制数。 作为 2 k 2^k 2k 进制数&#xff0c;除最后一位外&#xff0c; r r r 的每一位严格小于它右边相邻的那一位。 将 r r r …

PIL,OpenCV,Pytorch处理图像时的通道顺序(颜色,长宽深)

项目颜色通道顺序长宽通道顺序数据类型取值范围PILRGBHWCndarray0-255 (byte)OpenCVBGRHWCndarray0-255 (byte)PyTorchRGB/BGR (取决于如何读取)(N)CHWtensor0-1 (float, 标准化后); 0-255 (int, 未标准化) 注意以下几点&#xff1a; 颜色通道顺序&#xff1a;PIL默认使用RGB顺…

SQLite 附加数据库

SQLite 附加数据库 SQLite 是一种轻量级的数据库管理系统,因其小巧、快速和易于使用而广受欢迎。在 SQLite 中,可以将多个数据库文件附加到单个数据库连接中,从而允许用户在不同的数据库之间轻松切换和操作数据。本文将详细介绍如何在 SQLite 中附加数据库,并探讨其使用场…

网页封装APP:让您的网站变身移动应用

网页封装APP&#xff1a;让您的网站变身移动应用 随着移动设备的普及&#xff0c;越来越多的人开始使用移动设备浏览网站。但是&#xff0c;传统的网站设计并不适合移动设备的屏幕尺寸和交互方式&#xff0c;这导致了用户体验不佳和流失。 有没有办法让您的网站变身移动应用&…

qt 如果把像素点数据变成一个图片

1.概要 图像的本质是什么&#xff0c;就是一个个的像素点&#xff0c;对与显示器来说就是一个二维数组。无论多复杂的图片&#xff0c;对于显示器来说就是一个二维数组。 2.代码 #include "widget.h"#include <QApplication> #include <QImage> #incl…

HNU电子测试平台与工具2_《计算机串口使用与测量》

&#xff08;这个有留word哈哈&#xff09; 4.1 4.2 Linux 操作系统平台 一、实验目的 了解 Linux 系统文件系统的基本组织了解 Linux 基本的多用户权限系统熟练使用 ls、cd、cat、more、sudo、gcc、vim 等基本命令会使用 ls 和 chmod 命令查看和修改文件权限 二、实…