深度学习每周学习总结R2(RNN-天气预测)

server/2025/1/8 4:15:57/
  • 🍨 本文为🔗365天学习>深度学习训练营 中的学习记录博客R5中的内容,为了便于自己整理总结起名为R2
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

目录

    • 0. 总结
    • 1. RNN介绍
      • a. 什么是 RNN?
        • RNN 的一般应用场景
      • b. 传统 RNN 的基本结构
        • 关键特征
      • c. RNN 的优势与局限
        • 优势
        • 局限与改进
      • d. RNN 的常见变体:LSTM 和 GRU
        • LSTM (Long Short-Term Memory)
        • GRU (Gated Recurrent Unit)
      • e. RNN 的应用案例
      • f. RNN 在 PyTorch 中的实现方式
      • g. 如何更进一步学习 RNN?
      • h. 总结
    • 2. 数据导入
    • 3. 数据探索性分析
      • a. 数据相关性探索
      • b. 是否会下雨
      • c. 地理位置与下雨的关系
      • d. 湿度和压力对下雨的影响
      • e. 气温对下雨的影响
    • 4. 数据预处理
    • 5. 构建数据集
    • 6. 定义模型
    • 7. 初始化模型与优化器
    • 8. 训练函数
    • 9. 测试函数
    • 10. 执行训练
    • 11. 过程可视化

0. 总结

数据导入及处理部分:在 PyTorch 中,我们通常先将 NumPy 数组转换为 torch.Tensor,再封装到 TensorDataset 或自定义的 Dataset 里,然后用 DataLoader 按批次加载。

模型构建部分:RNN

设置超参数:在这之前需要定义损失函数,学习率(动态学习率),以及根据学习率定义优化器(例如SGD随机梯度下降),用来在训练中更新参数,最小化损失函数。

定义训练函数:函数的传入的参数有四个,分别是设置好的DataLoader(),定义好的模型,损失函数,优化器。函数内部初始化损失准确率为0,接着开始循环,使用DataLoader()获取一个批次的数据,对这个批次的数据带入模型得到预测值,然后使用损失函数计算得到损失值。接下来就是进行反向传播以及使用优化器优化参数,梯度清零放在反向传播之前或者是使用优化器优化之后都是可以的,一般是默认放在反向传播之前。

定义测试函数:函数传入的参数相比训练函数少了优化器,只需传入设置好的DataLoader(),定义好的模型,损失函数。此外除了处理批次数据时无需再设置梯度清零、返向传播以及优化器优化参数,其余部分均和训练函数保持一致。

训练过程:定义训练次数,有几次就使用整个数据集进行几次训练,初始化四个空list分别存储每次训练及测试的准确率及损失。使用model.train()开启训练模式,调用训练函数得到准确率及损失。使用model.eval()将模型设置为评估模式,调用测试函数得到准确率及损失。接着就是将得到的训练及测试的准确率及损失存储到相应list中并合并打印出来,得到每一次整体训练后的准确率及损失。

结果可视化

模型的保存,调取及使用。在PyTorch中,通常使用 torch.save(model.state_dict(), ‘model.pth’) 保存模型的参数,使用 model.load_state_dict(torch.load(‘model.pth’)) 加载参数。

需要改进优化的地方:确保模型和数据的一致性,都存到GPU或者CPU;注意numclasses不要直接用默认的1000,需要根据实际数据集改进;实例化模型也要注意numclasses这个参数;此外注意测试模型需要用(3,224,224)3表示通道数,这和tensorflow定义的顺序是不用的(224,224,3),做代码转换时需要注意。

1. RNN介绍

下面是对 RNN(Recurrent Neural Network) 的一个循序渐进、相对通俗的介绍,帮助你从原理上理解 RNN 的本质与应用,希望对你有所帮助。


a. 什么是 RNN?

RNN,全称 Recurrent Neural Network,即“循环神经网络”。它是一类专门处理序列数据神经网络模型,与传统的前馈网络(如全连接网络 MLP、卷积网络 CNN 等)最大的区别在于:

  • 序列性:RNN 可以在序列的时间步之间传递信息,具备“记忆”先前输入的能力。
  • 循环结构:在每一个时间步,网络都会基于当前输入上一时刻的隐藏状态来更新当前隐藏状态,然后输出结果。
RNN 的一般应用场景
  • 自然语言处理(NLP):如情感分析、文本分类、机器翻译、文本生成等。
  • 时间序列预测:如股票预测、温度预测、信号处理等。
  • 语音识别或合成:处理音频序列。

b. 传统 RNN 的基本结构

以下是一个最基础(经典版)的 RNN 结构示意:

  ┌───────┐      ┌───────┐      ┌───────┐ │x(t-1) │      │x(t)   │      │x(t+1) │  ← 输入序列└──┬────┘      └──┬────┘      └──┬────┘│              │              │┌─▼──────────────▼──────────────▼─────────────────────────┐│                    RNN 单元 (循环体)                     ││                                                        ││   h(t-1) ──┐   ┌─────────┐   ┌─────────┐               ││            │   │激活函数 f│   │激活函数 g│               ││ x(t), h(t-1) → │ 线性运算 → │ (如 tanh)  → h(t)          ││            │   └─────────┘   └─────────┘               │└────────────┴─────────────────────────────────────────────┘↑  通过时间传递(隐藏状态 h)
  • 输入序列:( x(1), x(2), …, x(T) )
  • 隐藏状态:( h(t) ) 表示网络在时间步 ( t ) 的内部记忆。
  • 更新公式(经典 RNN 的简单形式):
    [
    h(t) = \sigma(W_{hh} \cdot h(t-1) + W_{xh} \cdot x(t) + b_h)
    ]
    其中 (\sigma) 通常是一个非线性激活函数,如 (\tanh) 或 (\text{ReLU}) 等。
关键特征
  1. 循环(Recurrent)

    • RNN 通过将过去的隐藏状态 ( h(t-1) ) 反复输入到网络,与当前输入 ( x(t) ) 一起决策新的隐藏状态 ( h(t) )。因此它在时间序列上“循环”展开。
  2. 参数共享(Parameter Sharing)

    • 对于序列中每个时间步,RNN 使用相同的一组权重((W_{hh}, W_{xh}) 等),这与一般的多层感知器(MLP)不同,MLP 每一层都会有一组新的权重。
  3. 序列建模

    • 借助隐藏状态的更新,RNN 在一定程度上能够“记住”之前输入的信息,从而可以用来处理依赖于上下文或时间顺序的任务(如语言模型,每个单词与前面单词息息相关)。

c. RNN 的优势与局限

优势
  1. 适合序列数据:相比于传统的全连接网络,RNN 能够更好地处理变长的序列输入,捕捉序列中的时序依赖关系。
  2. 参数共享:节省模型参数,防止过度膨胀。
局限与改进
  1. 长期依赖问题:经典 RNN 里,随着序列长度增大,早期输入的信息往往无法传播到后面时间步,会导致梯度消失或梯度爆炸
  2. 训练效率:由于存在序列展开 + 反向传播(BPTT: Back Propagation Through Time)的特殊性,训练速度通常慢于并行度高的卷积网络。
  3. 改进模型
    • LSTM(Long Short-Term Memory)
    • GRU(Gated Recurrent Unit)
      这两种模型通过门控机制(忘记门、输入门、输出门等)来缓解或部分解决长期依赖问题,在实际中广泛使用。

d. RNN 的常见变体:LSTM 和 GRU

由于传统 RNN 在对长序列进行建模时,容易遗忘早期信息,为了解决这个问题,人们提出了带有 “门控” 机制的循环神经网络结构。其中最典型的就是 LSTMGRU

LSTM (Long Short-Term Memory)
  • 记忆单元(Cell state)和 门控机制(input gate、forget gate、output gate)来控制信息的流动,保留长期的梯度信息,从而缓解梯度消失问题。
  • 在很多 NLP 任务中,LSTM 大多表现优于传统 RNN。
GRU (Gated Recurrent Unit)
  • 结构上比 LSTM 更简化,只有 更新门重置门,虽然结构更简单,但也能保留一定的长期依赖能力。
  • 在某些任务中,GRU 的性能与 LSTM 不相上下,而且训练速度更快。

e. RNN 的应用案例

  1. 语言模型

    • 给定前面的单词,预测下一个单词;或给定一段前文,生成下一段文本。
    • 例如早期的机器翻译系统,输入序列是原语言单词,输出序列是翻译后的目标语言单词。
    • 现在更多使用了 Transformer 这种基于自注意力机制的模型,但 RNN 依然是重要的基石概念。
  2. 序列分类

    • 对一段文本或语音做分类,如情感分析(正向/负向)、语音识别(识别说的是哪一句话)等。
  3. 时间序列预测

    • 比如股票预测、流量预测、天气预测,通过过去若干时刻的数据预测未来走向。

f. RNN 在 PyTorch 中的实现方式

在 PyTorch 里,最常见的循环网络层包括:

  • nn.RNN:经典单层 RNN,可选激活函数 tanhReLU
  • nn.LSTM:LSTM 结构
  • nn.GRU:GRU 结构

输入通常需要形状 (batch_size, seq_len, input_size)(当 batch_first=True 时)。
输出需要自己选择:

  • 如果只需要最后一个时间步的输出,往往取 output[:, -1, :]
  • 如果需要所有时间步的输出(比如生成序列时),则直接使用 output
  • 训练时要记得将 hidden state(以及 cell state)正确地传递或重置。

g. 如何更进一步学习 RNN?

  1. 从小例子入手
    • 用 RNN 来解决简单的序列学习任务(例如正弦波预测、小规模字符级语言模型),查看网络是如何随时间迭代的。
  2. 阅读论文与教程
    • LSTM 的原始论文 (Hochreiter & Schmidhuber, 1997)
    • GRU (Cho et al., 2014)
    • 深入理解门控机制,体会为什么能让 RNN 更好地记住/遗忘信息。
  3. 与 Transformer 对比
    • 在大多数 NLP 任务上,目前已被 Transformer 结构占据主流,但 RNN 思想仍是许多研究的基础。理解 RNN 有助于理解注意力机制为什么行之有效。
  4. 深入到框架实现
    • 看 PyTorch 中 nn.RNNnn.LSTMnn.GRU 的源代码或官方文档,了解参数含义及前向、后向的具体计算流程。

h. 总结

  • 核心思想:RNN 可以“循环”地将过去的信息传递到现在,从而在一定程度上捕捉序列数据的依赖关系。
  • 传统 RNN 的问题:容易出现梯度消失或爆炸,难以捕捉长程依赖。
  • 常见改进:LSTM、GRU 等门控结构缓解了长期依赖难题,也成为 RNN 家族的主力。
  • 现今趋势:NLP 等领域更多使用 Transformer,但 RNN 在许多对序列长度不太长的场合依旧可以使用,而且对初学者理解神经网络的“记忆”能力非常有帮助。

如果你刚开始学习,可以:

  1. 多动手调试:写一些小规模 RNN 代码,训练简单的序列数据,观察 loss 和隐藏状态如何变化。
  2. 多画图:用纸笔画 RNN 在时序上的展开图,有助于理解反向传播的流程。
  3. 分门别类:清楚哪些任务用 LSTM/GRU,哪些任务需要 CNN 或 Transformer,知道各种模型的优势与局限。

2. 数据导入

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fimport copy
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error , mean_absolute_percentage_error , mean_squared_error
data = pd.read_csv("./data/weatherAUS.csv")
df   = data.copy()
data.head()
DateLocationMinTempMaxTempRainfallEvaporationSunshineWindGustDirWindGustSpeedWindDir9am...Humidity9amHumidity3pmPressure9amPressure3pmCloud9amCloud3pmTemp9amTemp3pmRainTodayRainTomorrow
02008-12-01Albury13.422.90.6NaNNaNW44.0W...71.022.01007.71007.18.0NaN16.921.8NoNo
12008-12-02Albury7.425.10.0NaNNaNWNW44.0NNW...44.025.01010.61007.8NaNNaN17.224.3NoNo
22008-12-03Albury12.925.70.0NaNNaNWSW46.0W...38.030.01007.61008.7NaN2.021.023.2NoNo
32008-12-04Albury9.228.00.0NaNNaNNE24.0SE...45.016.01017.61012.8NaNNaN18.126.5NoNo
42008-12-05Albury17.532.31.0NaNNaNW41.0ENE...82.033.01010.81006.07.08.017.829.7NoNo

5 rows × 23 columns

data.describe()
MinTempMaxTempRainfallEvaporationSunshineWindGustSpeedWindSpeed9amWindSpeed3pmHumidity9amHumidity3pmPressure9amPressure3pmCloud9amCloud3pmTemp9amTemp3pm
count143975.000000144199.000000142199.00000082670.00000075625.000000135197.000000143693.000000142398.000000142806.000000140953.000000130395.00000130432.00000089572.00000086102.000000143693.000000141851.00000
mean12.19403423.2213482.3609185.4682327.61117840.03523014.04342618.66265768.88083151.5391161017.649941015.2558894.4474614.50993016.99063121.68339
std6.3984957.1190498.4780604.1937043.78548313.6070628.9153758.80980019.02916420.7959027.106537.0374142.8871592.7203576.4887536.93665
min-8.500000-4.8000000.0000000.0000000.0000006.0000000.0000000.0000000.0000000.000000980.50000977.1000000.0000000.000000-7.200000-5.40000
25%7.60000017.9000000.0000002.6000004.80000031.0000007.00000013.00000057.00000037.0000001012.900001010.4000001.0000002.00000012.30000016.60000
50%12.00000022.6000000.0000004.8000008.40000039.00000013.00000019.00000070.00000052.0000001017.600001015.2000005.0000005.00000016.70000021.10000
75%16.90000028.2000000.8000007.40000010.60000048.00000019.00000024.00000083.00000066.0000001022.400001020.0000007.0000007.00000021.60000026.40000
max33.90000048.100000371.000000145.00000014.500000135.000000130.00000087.000000100.000000100.0000001041.000001039.6000009.0000009.00000040.20000046.70000
data.dtypes
Date              object
Location          object
MinTemp          float64
MaxTemp          float64
Rainfall         float64
Evaporation      float64
Sunshine         float64
WindGustDir       object
WindGustSpeed    float64
WindDir9am        object
WindDir3pm        object
WindSpeed9am     float64
WindSpeed3pm     float64
Humidity9am      float64
Humidity3pm      float64
Pressure9am      float64
Pressure3pm      float64
Cloud9am         float64
Cloud3pm         float64
Temp9am          float64
Temp3pm          float64
RainToday         object
RainTomorrow      object
dtype: object

3. 数据探索性分析

#将数据转换为日期时间格式
data['Date'] = pd.to_datetime(data['Date'])data['year']  = data['Date'].dt.year
data['Month'] = data['Date'].dt.month
data['day']   = data['Date'].dt.daydata.head()
DateLocationMinTempMaxTempRainfallEvaporationSunshineWindGustDirWindGustSpeedWindDir9am...Pressure3pmCloud9amCloud3pmTemp9amTemp3pmRainTodayRainTomorrowyearMonthday
02008-12-01Albury13.422.90.6NaNNaNW44.0W...1007.18.0NaN16.921.8NoNo2008121
12008-12-02Albury7.425.10.0NaNNaNWNW44.0NNW...1007.8NaNNaN17.224.3NoNo2008122
22008-12-03Albury12.925.70.0NaNNaNWSW46.0W...1008.7NaN2.021.023.2NoNo2008123
32008-12-04Albury9.228.00.0NaNNaNNE24.0SE...1012.8NaNNaN18.126.5NoNo2008124
42008-12-05Albury17.532.31.0NaNNaNW41.0ENE...1006.07.08.017.829.7NoNo2008125

5 rows × 26 columns

data.drop('Date',axis=1,inplace=True)
data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine','WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm','WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm','Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am','Temp3pm', 'RainToday', 'RainTomorrow', 'year', 'Month', 'day'],dtype='object')

a. 数据相关性探索

plt.figure(figsize=(15,13))
# data.corr()表示了data中的两个变量之间的相关性
ax = sns.heatmap(data.corr(), square=True, annot=True, fmt='.2f')
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)          
plt.show()

在这里插入图片描述


b. 是否会下雨

# 设置样式和调色板
sns.set(style="whitegrid", palette="Set2")# 创建一个 1 行 2 列的图像布局
fig, axes = plt.subplots(1, 2, figsize=(10, 4))  # 图形尺寸调大 (10, 4)# 图表标题样式
title_font = {'fontsize': 14, 'fontweight': 'bold', 'color': 'darkblue'}# 第一张图:RainTomorrow
sns.countplot(x='RainTomorrow', data=data, ax=axes[0], edgecolor='black')  # 添加边框
axes[0].set_title('Rain Tomorrow', fontdict=title_font)  # 设置标题
axes[0].set_xlabel('Will it Rain Tomorrow?', fontsize=12)  # X轴标签
axes[0].set_ylabel('Count', fontsize=12)  # Y轴标签
axes[0].tick_params(axis='x', labelsize=11)  # X轴刻度字体大小
axes[0].tick_params(axis='y', labelsize=11)  # Y轴刻度字体大小# 第二张图:RainToday
sns.countplot(x='RainToday', data=data, ax=axes[1], edgecolor='black')  # 添加边框
axes[1].set_title('Rain Today', fontdict=title_font)  # 设置标题
axes[1].set_xlabel('Did it Rain Today?', fontsize=12)  # X轴标签
axes[1].set_ylabel('Count', fontsize=12)  # Y轴标签
axes[1].tick_params(axis='x', labelsize=11)  # X轴刻度字体大小
axes[1].tick_params(axis='y', labelsize=11)  # Y轴刻度字体大小sns.despine()      # 去除图表顶部和右侧的边框
plt.tight_layout() # 调整布局,避免图形之间的重叠
plt.show()

在这里插入图片描述

x=pd.crosstab(data['RainTomorrow'],data['RainToday'])
x
RainTodayNoYes
RainTomorrow
No9272816858
Yes1660414597
y=x/x.transpose().sum().values.reshape(2,1)*100
y
RainTodayNoYes
RainTomorrow
No84.61664815.383352
Yes53.21624346.783757
  • 如果今天不下雨,那么明天下雨的机会 = 53.22%

  • 如果今天下雨明天下雨的机会 = 46.78%

y.plot(kind="bar",figsize=(4,3),color=['#006666','#d279a6']);

在这里插入图片描述

c. 地理位置与下雨的关系

x=pd.crosstab(data['Location'],data['RainToday']) 
# 获取每个城市下雨天数和非下雨天数的百分比
y=x/x.transpose().sum().values.reshape((-1, 1))*100
# 按每个城市的雨天百分比排序
y=y.sort_values(by='Yes',ascending=True )color=['#cc6699','#006699','#006666','#862d86','#ff9966'  ]
y.Yes.plot(kind="barh",figsize=(15,20),color=color)
<Axes: ylabel='Location'>

在这里插入图片描述

位置影响下雨,对于 Portland 来说,有 36% 的时间在下雨,而对于 Woomers 来说,只有6%的时间在下雨

d. 湿度和压力对下雨的影响

data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine','WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm','WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm','Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am','Temp3pm', 'RainToday', 'RainTomorrow', 'year', 'Month', 'day'],dtype='object')
plt.figure(figsize=(8,6))
sns.scatterplot(data=data,x='Pressure9am',y='Pressure3pm',hue='RainTomorrow');

在这里插入图片描述

plt.figure(figsize=(8,6))
sns.scatterplot(data=data,x='Humidity9am',y='Humidity3pm',hue='RainTomorrow');

在这里插入图片描述


低压与高湿度会增加第二天下雨的概率,尤其是下午 3 点的空气湿度。

e. 气温对下雨的影响

plt.figure(figsize=(8,6))
sns.scatterplot(x='MaxTemp', y='MinTemp', data=data, hue='RainTomorrow');

请添加图片描述

4. 数据预处理

处理缺损值

# 每列中缺失数据的百分比
data.isnull().sum()/data.shape[0]*100
Location          0.000000
MinTemp           1.020899
MaxTemp           0.866905
Rainfall          2.241853
Evaporation      43.166506
Sunshine         48.009762
WindGustDir       7.098859
WindGustSpeed     7.055548
WindDir9am        7.263853
WindDir3pm        2.906641
WindSpeed9am      1.214767
WindSpeed3pm      2.105046
Humidity9am       1.824557
Humidity3pm       3.098446
Pressure9am      10.356799
Pressure3pm      10.331363
Cloud9am         38.421559
Cloud3pm         40.807095
Temp9am           1.214767
Temp3pm           2.481094
RainToday         2.241853
RainTomorrow      2.245978
year              0.000000
Month             0.000000
day               0.000000
dtype: float64
# 在该列中随机选择数进行填充
lst=['Evaporation','Sunshine','Cloud9am','Cloud3pm']
for col in lst:fill_list = data[col].dropna()data[col] = data[col].fillna(pd.Series(np.random.choice(fill_list, size=len(data.index))))
s = (data.dtypes == "object")
object_cols = list(s[s].index)
object_cols
['Location','WindGustDir','WindDir9am','WindDir3pm','RainToday','RainTomorrow']
# inplace=True:直接修改原对象,不创建副本
# data[i].mode()[0] 返回频率出现最高的选项,众数for i in object_cols:data[i].fillna(data[i].mode()[0], inplace=True)
t = (data.dtypes == "float64")
num_cols = list(t[t].index)
num_cols
['MinTemp','MaxTemp','Rainfall','Evaporation','Sunshine','WindGustSpeed','WindSpeed9am','WindSpeed3pm','Humidity9am','Humidity3pm','Pressure9am','Pressure3pm','Cloud9am','Cloud3pm','Temp9am','Temp3pm']
# .median(), 中位数
for i in num_cols:data[i].fillna(data[i].median(), inplace=True)
data.isnull().sum()
Location         0
MinTemp          0
MaxTemp          0
Rainfall         0
Evaporation      0
Sunshine         0
WindGustDir      0
WindGustSpeed    0
WindDir9am       0
WindDir3pm       0
WindSpeed9am     0
WindSpeed3pm     0
Humidity9am      0
Humidity3pm      0
Pressure9am      0
Pressure3pm      0
Cloud9am         0
Cloud3pm         0
Temp9am          0
Temp3pm          0
RainToday        0
RainTomorrow     0
year             0
Month            0
day              0
dtype: int64

5. 构建数据集

from sklearn.preprocessing import LabelEncoderlabel_encoder = LabelEncoder()
for i in object_cols:data[i] = label_encoder.fit_transform(data[i])
X = data.drop(['RainTomorrow','day'],axis=1).values
y = data['RainTomorrow'].values
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.25,random_state=101)
scaler = MinMaxScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test  = scaler.transform(X_test)
# 创建pytorch dataset 和 dataloader
"""
在 PyTorch 中,我们通常先将 NumPy 数组转换为 torch.Tensor,
再封装到 TensorDataset 或自定义的 Dataset 里,然后用 DataLoader 按批次加载。
"""
import torch
from torch.utils.data import Dataset, DataLoader, TensorDatasetX_train = X_train.reshape(X_train.shape[0],X_train.shape[1],1)
X_test = X_test.reshape(X_test.shape[0],X_test.shape[1],1)# 如果要做二分类 + Sigmoid + nn.BCELoss,那么标签可以用 float32
# 如果要做多分类(例如 softmax + CrossEntropy),则需把标签转为 long
y_train = y_train.astype(np.float32)  # 二分类: float32
y_test = y_test.astype(np.float32)    # 二分类: float32# 转换为张量
X_train_tensor = torch.from_numpy(X_train).float()  # shape:[samples, 13, 1]
y_train_tensor = torch.from_numpy(y_train)          # shape:[samples]X_test_tensor = torch.from_numpy(X_test).float()
y_test_tensor = torch.from_numpy(y_test)# 如果后续需要在训练中对标签执行 pred>0.5 判定,可以保持 y 的 shape=[samples] 即可
# 也可 reshape([-1,1]) 保持和网络输出尺寸一致,不过这并非必须。
# y_train_tensor = y_train_tensor.view(-1,1)
# y_test_tensor = y_test_tensor.view(-1,1)# 用 TensorDataset 直接封装
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)# 创建 DataLoader
batch_size = 32train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

6. 定义模型

### 构建RNN模型
# -----------------------------
# 1. 定义模型结构
# -----------------------------
class SimpleRNNModel(nn.Module):def __init__(self):super(SimpleRNNModel, self).__init__()# TensorFlow 中 input_shape=(13,1),即序列长度 seq_len = 13,特征维度 input_dim = 1# PyTorch RNN 层若设置 batch_first=True:#   输入张量形状: (batch_size, seq_len, input_dim)#   输出张量形状: (batch_size, seq_len, hidden_size)self.rnn = nn.RNN(input_size=1,         # 对应 TF 的 input_dim=1hidden_size=200,      # 对应 TF 的 RNN(200)batch_first=True,nonlinearity='relu'   # 对应 TF 的 activation='relu' )self.fc1 = nn.Linear(200, 100)  # 对应 Dense(100, activation='relu')self.fc2 = nn.Linear(100, 1)    # 对应 Dense(1, activation='sigmoid')self.sigmoid = nn.Sigmoid()def forward(self, x):# x: [batch_size, 13, 1]# RNN 输出: output, hidden#   output shape = [batch_size, seq_len, hidden_size]#   hidden shape = [num_layers, batch_size, hidden_size]out, hidden = self.rnn(x)# 取最后一个 time_step 的输出, 与 TensorFlow 里 SimpleRNN 的默认行为一致out = out[:, -1, :]  # shape: [batch_size, hidden_size]# 与 Dense(100, relu)out = F.relu(self.fc1(out)) # [batch_size, 100]# 与 Dense(1, sigmoid)out = self.sigmoid(self.fc2(out)) # [batch_size, 1]return out

7. 初始化模型与优化器

# -----------------------------
# 2. 初始化模型与优化器
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = SimpleRNNModel().to(device)
print(model)# 与 TF 中 loss='binary_crossentropy' 对应,PyTorch 用 BCE:nn.BCELoss
loss_fn = nn.BCELoss()# 多分类问题使用nn.CrossEntropyLoss()
# criterion = nn.CrossEntropyLoss()learn_rate = 1e-4  
# learn_rate = 3e-4
lambda1 = lambda epoch:(0.92**(epoch//2))optimizer = torch.optim.Adam(model.parameters(),lr = learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1) # 选定调整方法
SimpleRNNModel((rnn): RNN(1, 200, batch_first=True)(fc1): Linear(in_features=200, out_features=100, bias=True)(fc2): Linear(in_features=100, out_features=1, bias=True)(sigmoid): Sigmoid()
)

8. 训练函数

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集大小num_batches = len(dataloader)   # 批次数目train_loss, train_acc = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)# 计算预测pred = model(X).view(-1) # [batch_size]loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录acc与loss# 情况1: 如果是多分类(N>1), pred.shape=[batch_size, N],可以用argmax(1).# train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()# 情况2: 如果是二分类且只有1个输出(使用 Sigmoid),则 pred.shape=[batch_size,1],# 那么可用 (pred>0.5) 转为0/1来比较:pred_label = (pred > 0.5).long() # [batch_size]train_acc += (pred_label == y.long()).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

9. 测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_acc, test_loss = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)# 计算预测pred = model(X).view(-1) # [batch_size]loss = loss_fn(pred, y)# 情况1: 多分类(N>1):# test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()# 情况2: 二分类单输出:pred_label = (pred > 0.5).long() # [batch_size]# test_acc += (pred_label.view(-1) == y).type(torch.float).sum().item()test_acc += (pred_label == y.long()).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

10. 执行训练

# -----------------------------
# 打印可用 GPU 信息
# -----------------------------
if torch.cuda.is_available():for i in range(torch.cuda.device_count()):print(f"GPU {i}: {torch.cuda.get_device_name(i)}")print(f"Initial Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")print(f"Initial Memory Reserved: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")
else:print("No GPU available. Using CPU.")# -----------------------------
# 训练主循环
# -----------------------------
epochs = 60train_acc_list = []
train_loss_list = []
test_acc_list = []
test_loss_list = []best_acc = 0.0
best_model = Nonefor epoch in range(epochs):# 更新学习率——使用自定义学习率时使用# adjust_learning_rate(optimizer,epoch,learn_rate)# 切换为训练模式model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)# 更新学习scheduler.step() # 更新学习率——调用官方动态学习率时使用# 切换为评估模式model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)# 保存最佳模型if epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc_list.append(epoch_train_acc)train_loss_list.append(epoch_train_loss)test_acc_list.append(epoch_test_acc)test_loss_list.append(epoch_test_loss)# 当前学习lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, ''Train_acc:{:.1f}%, Train_loss:{:.3f}, ''Test_acc:{:.1f}%, Test_loss:{:.3f}, ''Lr:{:.2E}')print(template.format(epoch+1,epoch_train_acc*100, epoch_train_loss,epoch_test_acc*100, epoch_test_loss,lr))# 实时监控 GPU 状态if torch.cuda.is_available():for i in range(torch.cuda.device_count()):print(f"GPU {i} Usage:")print(f"  Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")print(f"  Memory Reserved: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")print(f"  Max Memory Allocated: {torch.cuda.max_memory_allocated(i)/1024**2:.2f} MB")print(f"  Max Memory Reserved: {torch.cuda.max_memory_reserved(i)/1024**2:.2f} MB")print('Done. Best test acc: ', best_acc)
No GPU available. Using CPU.
Epoch: 1, Train_acc:80.1%, Train_loss:0.460, Test_acc:82.7%, Test_loss:0.397, Lr:1.00E-04
Epoch: 2, Train_acc:83.4%, Train_loss:0.387, Test_acc:83.8%, Test_loss:0.374, Lr:9.20E-05
Epoch: 3, Train_acc:83.9%, Train_loss:0.375, Test_acc:84.1%, Test_loss:0.367, Lr:9.20E-05
Epoch: 4, Train_acc:83.9%, Train_loss:0.370, Test_acc:84.2%, Test_loss:0.365, Lr:8.46E-05
Epoch: 5, Train_acc:84.1%, Train_loss:0.368, Test_acc:83.9%, Test_loss:0.375, Lr:8.46E-05
Epoch: 6, Train_acc:84.3%, Train_loss:0.366, Test_acc:84.3%, Test_loss:0.364, Lr:7.79E-05
Epoch: 7, Train_acc:84.3%, Train_loss:0.365, Test_acc:84.3%, Test_loss:0.363, Lr:7.79E-05
Epoch: 8, Train_acc:84.3%, Train_loss:0.364, Test_acc:84.3%, Test_loss:0.362, Lr:7.16E-05
Epoch: 9, Train_acc:84.3%, Train_loss:0.364, Test_acc:84.4%, Test_loss:0.362, Lr:7.16E-05
Epoch:10, Train_acc:84.4%, Train_loss:0.362, Test_acc:84.3%, Test_loss:0.363, Lr:6.59E-05
Epoch:11, Train_acc:84.3%, Train_loss:0.361, Test_acc:84.4%, Test_loss:0.363, Lr:6.59E-05
Epoch:12, Train_acc:84.4%, Train_loss:0.361, Test_acc:84.4%, Test_loss:0.359, Lr:6.06E-05
Epoch:13, Train_acc:84.5%, Train_loss:0.360, Test_acc:84.4%, Test_loss:0.362, Lr:6.06E-05
Epoch:14, Train_acc:84.4%, Train_loss:0.360, Test_acc:84.5%, Test_loss:0.359, Lr:5.58E-05
Epoch:15, Train_acc:84.5%, Train_loss:0.358, Test_acc:84.4%, Test_loss:0.358, Lr:5.58E-05
Epoch:16, Train_acc:84.5%, Train_loss:0.358, Test_acc:84.5%, Test_loss:0.361, Lr:5.13E-05
Epoch:17, Train_acc:84.6%, Train_loss:0.357, Test_acc:84.5%, Test_loss:0.358, Lr:5.13E-05
Epoch:18, Train_acc:84.6%, Train_loss:0.357, Test_acc:84.6%, Test_loss:0.357, Lr:4.72E-05
Epoch:19, Train_acc:84.6%, Train_loss:0.356, Test_acc:84.6%, Test_loss:0.357, Lr:4.72E-05
Epoch:20, Train_acc:84.7%, Train_loss:0.356, Test_acc:84.6%, Test_loss:0.356, Lr:4.34E-05
Epoch:21, Train_acc:84.6%, Train_loss:0.355, Test_acc:84.6%, Test_loss:0.356, Lr:4.34E-05
Epoch:22, Train_acc:84.6%, Train_loss:0.355, Test_acc:84.6%, Test_loss:0.356, Lr:4.00E-05
Epoch:23, Train_acc:84.7%, Train_loss:0.354, Test_acc:84.6%, Test_loss:0.356, Lr:4.00E-05
Epoch:24, Train_acc:84.7%, Train_loss:0.354, Test_acc:84.5%, Test_loss:0.358, Lr:3.68E-05
Epoch:25, Train_acc:84.7%, Train_loss:0.353, Test_acc:84.6%, Test_loss:0.357, Lr:3.68E-05
Epoch:26, Train_acc:84.8%, Train_loss:0.353, Test_acc:84.7%, Test_loss:0.354, Lr:3.38E-05
Epoch:27, Train_acc:84.7%, Train_loss:0.352, Test_acc:84.7%, Test_loss:0.353, Lr:3.38E-05
Epoch:28, Train_acc:84.8%, Train_loss:0.352, Test_acc:84.7%, Test_loss:0.354, Lr:3.11E-05
Epoch:29, Train_acc:84.8%, Train_loss:0.352, Test_acc:84.8%, Test_loss:0.354, Lr:3.11E-05
Epoch:30, Train_acc:84.9%, Train_loss:0.352, Test_acc:84.8%, Test_loss:0.353, Lr:2.86E-05
Epoch:31, Train_acc:84.9%, Train_loss:0.351, Test_acc:84.7%, Test_loss:0.356, Lr:2.86E-05
Epoch:32, Train_acc:84.9%, Train_loss:0.351, Test_acc:84.6%, Test_loss:0.354, Lr:2.63E-05
Epoch:33, Train_acc:84.8%, Train_loss:0.350, Test_acc:84.8%, Test_loss:0.352, Lr:2.63E-05
Epoch:34, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.7%, Test_loss:0.354, Lr:2.42E-05
Epoch:35, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.8%, Test_loss:0.352, Lr:2.42E-05
Epoch:36, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.6%, Test_loss:0.354, Lr:2.23E-05
Epoch:37, Train_acc:84.9%, Train_loss:0.349, Test_acc:84.8%, Test_loss:0.353, Lr:2.23E-05
Epoch:38, Train_acc:84.9%, Train_loss:0.349, Test_acc:84.6%, Test_loss:0.356, Lr:2.05E-05
Epoch:39, Train_acc:85.0%, Train_loss:0.348, Test_acc:85.0%, Test_loss:0.351, Lr:2.05E-05
Epoch:40, Train_acc:85.0%, Train_loss:0.349, Test_acc:84.8%, Test_loss:0.351, Lr:1.89E-05
Epoch:41, Train_acc:85.0%, Train_loss:0.348, Test_acc:84.8%, Test_loss:0.351, Lr:1.89E-05
Epoch:42, Train_acc:85.0%, Train_loss:0.348, Test_acc:84.9%, Test_loss:0.351, Lr:1.74E-05
Epoch:43, Train_acc:85.0%, Train_loss:0.347, Test_acc:85.0%, Test_loss:0.350, Lr:1.74E-05
Epoch:44, Train_acc:85.0%, Train_loss:0.347, Test_acc:85.0%, Test_loss:0.351, Lr:1.60E-05
Epoch:45, Train_acc:85.1%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.350, Lr:1.60E-05
Epoch:46, Train_acc:85.1%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.350, Lr:1.47E-05
Epoch:47, Train_acc:85.0%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.351, Lr:1.47E-05
Epoch:48, Train_acc:85.0%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.35E-05
Epoch:49, Train_acc:85.1%, Train_loss:0.346, Test_acc:85.0%, Test_loss:0.349, Lr:1.35E-05
Epoch:50, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.24E-05
Epoch:51, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.349, Lr:1.24E-05
Epoch:52, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.14E-05
Epoch:53, Train_acc:85.1%, Train_loss:0.345, Test_acc:84.9%, Test_loss:0.349, Lr:1.14E-05
Epoch:54, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.349, Lr:1.05E-05
Epoch:55, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.349, Lr:1.05E-05
Epoch:56, Train_acc:85.1%, Train_loss:0.345, Test_acc:84.8%, Test_loss:0.350, Lr:9.68E-06
Epoch:57, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.348, Lr:9.68E-06
Epoch:58, Train_acc:85.1%, Train_loss:0.344, Test_acc:84.8%, Test_loss:0.350, Lr:8.91E-06
Epoch:59, Train_acc:85.2%, Train_loss:0.344, Test_acc:85.0%, Test_loss:0.348, Lr:8.91E-06
Epoch:60, Train_acc:85.2%, Train_loss:0.344, Test_acc:84.9%, Test_loss:0.349, Lr:8.20E-06
Done. Best test acc:  0.8500206242265915

11. 过程可视化

epochs_range = range(epochs)plt.figure(figsize=(12, 5))# 准确率曲线
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc_list, label='Training Accuracy')
plt.plot(epochs_range, test_acc_list, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')# 损失曲线
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss_list, label='Training Loss')
plt.plot(epochs_range, test_loss_list, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')plt.show()

请添加图片描述



http://www.ppmy.cn/server/155830.html

相关文章

Java精心打造:安全可靠的共享棋牌室系统

软件开发技术栈 java后台技术 springbootmybatismysqlredis 小程序用户端、小程序门店管理员端、小程序门店保洁员端 uniapp 总后台、门店PC管理前端 vue elementUI 用户预约小程序 小程序自定义 支持独立小程序运行&#xff0c;可以自己设置小程序的名称、logo、加盟电话…

设计心得——流程图和数据流图绘制

一、流程图和数据流图 在软件开发中&#xff0c;画流程图和数据流图可以说是几乎每个人都会遇到。 1、数据流&#xff08;程&#xff09;图 Data Flow Diagram&#xff0c;DFG。它可以称为数据流图或数据流程图。其主要用来描述系统中数据流程的一种图形工具&#xff0c;可以将…

mysql error:1071 -Specified key was too long; max key length is 767 bytes

错误原因 数据库表采用utf8编码时&#xff0c;当对varchar(255)的列设置唯一键索引时发生该错误。 mysql默认单列的索引不能超过767位(不同版本可能存在差异) 解决方法 &#xff08;1&#xff09; 使用innodb引擎&#xff1b; &#xff08;2&#xff09; 启用innodb_large_p…

WebRTC的线程模型

WebRTC中的线程类&#xff1a; Thread类&#xff1a; &#xff08;1&#xff09;Thread类中的数据&#xff1a; class Thread {// 消息队列&#xff1a;MessageList messages_; // 消息队列&#xff0c;所有需要线程处理的消息&#xff0c;都要先入队PriorityQueue delayed_m…

【《python爬虫入门教程11--重剑无峰168》】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 【《python爬虫入门教程11--selenium的安装与使用》】 前言selenium就是一个可以实现python自动化的模块 一、Chrome的版本查找&#xff1f;-- 如果用edge也是类似的1.chrome…

大数据学习(33)-续集

今天开始重新更新大数据 -- 感谢大家的支持&#xff01;&#xff01;&#xff01;

[Linux]进程间通信-管道

目录 1. 进程间通信 2.父子进程之间的通信 3.匿名管道 匿名管道的创建 管道读写的情况 管道的5种特性 4.命名管道 指令级 命名管道原理 代码级 读取端 1. 进程间通信 当我们有两个进程操作数据库的时候&#xff0c;一个进程负责写入操作&#xff0c;一个进…

npm提示Install fail! Error_ EBUSY_ resource busy or

问题 在命令行下&#xff0c;通过NPM 命令来安装插件&#xff0c;弹出提示Install fail! Error: EBUSY: resource busy or locked, symlink 解决方法 出现这样的错误大概率是文件被占用&#xff0c;导致文件或者文件夹无法删除造成的&#xff0c; 1.尝试执行npm cache clea…