pytorch深度Q网络

ops/2025/2/4 20:55:09/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

DQN 引入了深度神经网络来近似Q函数,解决了传统Q-learning在处理高维状态空间时的瓶颈,尤其是在像 Atari 游戏这样的复杂环境中。DQN的核心思想是使用神经网络 Q(s,a;θ)Q(s, a; \theta)Q(s,a;θ) 来近似 Q 值函数,其中 θ\thetaθ 是神经网络的参数。

DQN 的关键创新包括:

  1. 经验回放(Experience Replay):在强化学习中,当前的学习可能会依赖于最近的经验,容易导致学习过程的不稳定。经验回放通过将智能体的经历存储到一个回放池中,然后随机抽取批量数据进行训练,这样可以打破数据之间的相关性,使得训练更加稳定。

  2. 目标网络(Target Network):在Q-learning中,Q值的更新依赖于下一个状态的最大Q值。为了避免Q值更新时过度依赖当前网络的输出(导致不稳定),DQN引入了目标网络。目标网络的结构与行为网络相同,但它的参数更新频率较低,这使得Q值更新更加稳定。

DQN算法流程

  1. 初始化Q网络:初始化Q网络的参数 θ\thetaθ,以及目标网络的参数 θ−\theta^-θ−(通常与Q网络相同)。
  2. 行为选择:基于当前的Q网络来选择动作(通常使用ε-greedy策略,即以ε的概率选择随机动作,否则选择当前Q值最大的动作)。
  3. 执行动作并存储经验:执行所选动作,观察奖励,并记录状态转移 (st,at,rt+1,st+1)(s_t, a_t, r_{t+1}, s_{t+1})(st​,at​,rt+1​,st+1​)。
  4. 经验回放:从回放池中随机抽取一个小批量的经验数据。
  5. 计算Q值目标:对于每个样本,计算目标值 y=rt+1+γmax⁡a′Q(st+1,a′;θ−)y = r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a'; \theta^-)y=rt+1​+γmaxa′​Q(st+1​,a′;θ−)。
  6. 更新Q网络:通过最小化损失函数 L(θ)=1N∑(y−Q(st,at;θ))2L(\theta) = \frac{1}{N} \sum (y - Q(s_t, a_t; \theta))^2L(θ)=N1​∑(y−Q(st​,at​;θ))2 来更新Q网络的参数。
  7. 周期性更新目标网络:每隔一段时间,将Q网络的参数复制到目标网络。

DQN的应用

DQN在多个领域取得了重要应用,尤其是在强化学习任务中:

  • Atari 游戏:DQN 在多个经典的 Atari 游戏上成功展示了其能力,比如《Breakout》和《Pong》等。
  • 机器人控制:利用DQN,机器人可以在复杂的环境中自主学习如何执行任务。
  • 自动驾驶:在自动驾驶领域,DQN可以用来训练智能体通过道路、避开障碍物等。

例子:

这里我们手动实现一个非常简单的环境:一个1D平衡问题,类似于一个可以左右移动的棒球,目标是让它保持在某个位置上。

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt# 自定义环境
class SimpleEnv:def __init__(self):self.state = 0.0  # 初始状态self.goal = 10.0  # 目标位置self.done = Falsedef reset(self):self.state = 0.0self.done = Falsereturn self.statedef step(self, action):if self.done:return self.state, 0, self.done  # 游戏结束,不再变化# 通过动作修改状态self.state += action  # 动作是 -1、0、1,控制移动方向reward = -abs(self.state - self.goal)  # 奖励是距离目标位置的负值# 如果距离目标很近,就结束if abs(self.state - self.goal) < 0.1:self.done = Truereward = 10  # 达到目标时奖励较高return self.state, reward, self.done# Q网络定义
class QNetwork(nn.Module):def __init__(self, input_dim, output_dim):super(QNetwork, self).__init__()self.fc = nn.Linear(input_dim, 24)self.fc2 = nn.Linear(24, output_dim)def forward(self, x):x = torch.relu(self.fc(x))x = self.fc2(x)return x# DQN智能体
class DQN:def __init__(self, env, gamma=0.99, epsilon=0.1, batch_size=32, learning_rate=1e-3):self.env = envself.gamma = gammaself.epsilon = epsilonself.batch_size = batch_sizeself.learning_rate = learning_rateself.input_dim = 1  # 因为环境状态是一个单一的数值self.output_dim = 3  # 动作空间大小:-1, 0, 1self.q_network = QNetwork(self.input_dim, self.output_dim)self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)self.criterion = nn.MSELoss()def select_action(self, state):if random.random() < self.epsilon:return random.choice([-1, 0, 1])  # 随机选择动作state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)with torch.no_grad():q_values = self.q_network(state)# 将动作值 -1, 0, 1 转换为索引 0, 1, 2action_idx = torch.argmax(q_values, dim=1).item()action_map = [-1, 0, 1]  # -1 -> 0, 0 -> 1, 1 -> 2return action_map[action_idx]def update(self, state, action, reward, next_state, done):state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)# 将动作 -1, 0, 1 转换为索引 0, 1, 2action_map = [-1, 0, 1]action_idx = action_map.index(action)action = torch.tensor(action_idx, dtype=torch.long).unsqueeze(0)reward = torch.tensor(reward, dtype=torch.float32).unsqueeze(0)# 确保done是Python标准bool类型done = torch.tensor(done, dtype=torch.float32).unsqueeze(0)# 计算目标Q值with torch.no_grad():next_q_values = self.q_network(next_state)next_q_value = next_q_values.max(1)[0]target_q_value = reward + self.gamma * next_q_value * (1 - done)# 获取当前Q值q_values = self.q_network(state)action_q_values = q_values.gather(1, action.unsqueeze(1)).squeeze(1)# 计算损失并更新Q网络loss = self.criterion(action_q_values, target_q_value)self.optimizer.zero_grad()loss.backward()self.optimizer.step()def train(self, num_episodes=200):rewards = []best_reward = -float('inf')  # 初始最好的奖励设为负无穷best_episode = 0for episode in range(num_episodes):state = self.env.reset()  # 获取初始状态total_reward = 0done = Falsewhile not done:action = self.select_action([state])next_state, reward, done = self.env.step(action)total_reward += reward# 更新Q网络self.update([state], action, reward, [next_state], done)state = next_staterewards.append(total_reward)# 记录最佳奖励和对应的episodeif total_reward > best_reward:best_reward = total_rewardbest_episode = episodeprint(f"Episode {episode}, Total Reward: {total_reward}")# 打印最佳结果print(f"Best Reward: {best_reward} at Episode {best_episode}")# 绘制奖励图plt.plot(rewards)plt.title('Total Rewards per Episode')plt.xlabel('Episode')plt.ylabel('Total Reward')# 在最佳位置添加标记plt.scatter(best_episode, best_reward, color='red', label=f"Best Reward at Episode {best_episode}")plt.legend()plt.show()# 初始化环境和DQN智能体
env = SimpleEnv()
dqn = DQN(env)# 训练智能体
dqn.train()


http://www.ppmy.cn/ops/155668.html

相关文章

初始JavaEE篇 —— Spring Web MVC入门(上)

找往期文章包括但不限于本期文章中不懂的知识点&#xff1a; 个人主页&#xff1a;我要学编程程(ಥ_ಥ)-CSDN博客 所属专栏&#xff1a;JavaEE 目录 RequestMappingg 注解介绍 Postman的介绍与使用 PostMapping 与 GetMapping 注解 构造并接收请求 接收简单参数 接收对象…

pytorch实现简单的情感分析算法

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 在PyTorch中实现中文情感分析算法通常涉及以下几个步骤&#xff1a;数据预处理、模型定义、训练和评估。下面是一个简单的实现示例&#xff0c;使用LSTM模型进行中文情感分析。 1. 数据预处理 首先&#xff0c;我…

8、面向对象:类、封装、构造方法

一、类 1、定义 类&#xff1a;对现实世界中事物的抽象。Student 对象&#xff1a;现实世界中具体的个体。张三、李四 这些具体的学生 面向对象的特征&#xff1a;抽象、封装、继承、多态 OOP: Object Oriented Programming&#xff08;面向对象编程&#xff09; 类和对象…

Paddle和pytorch不可以同时引用

import paddleprint(paddle.utils.run_check())import torch print(torch.version.cuda)print(torch.backends.cudnn.version()) 报错&#xff1a; OSError: [WinError 127] 找不到指定的程序。 Error loading "C:\Program Files\Python311\Lib\site-packages\torch\li…

求职刷题力扣DAY34--贪心算法part05

Definition for a binary tree node. class TreeNode: def init(self, val0, leftNone, rightNone): self.val val self.left left self.right right class Solution: def minCameraCover(self, root: Optional[TreeNode]) -> int: # 三种状态0&#xff1a;没有覆盖…

深度学习编译器的演进:从计算图到跨硬件部署的自动化之路

第一章 问题的诞生——深度学习部署的硬件困境 1.1 计算图的理想化抽象 什么是计算图&#xff1f; 想象你正在组装乐高积木。每个积木块代表一个数学运算&#xff08;如加法、乘法&#xff09;&#xff0c;积木之间的连接代表数据流动。深度学习框架正是用这种"积木拼接…

MySQL(InnoDB统计信息)

后面也会持续更新&#xff0c;学到新东西会在其中补充。 建议按顺序食用&#xff0c;欢迎批评或者交流&#xff01; 缺什么东西欢迎评论&#xff01;我都会及时修改的&#xff01; 大部分截图和文章采用该书&#xff0c;谢谢这位大佬的文章&#xff0c;在这里真的很感谢让迷茫的…

[SAP ABAP] ABAP SQL跟踪工具

事务码ST05 操作步骤 步骤1&#xff1a;使用事务码ST05之前&#xff0c;将要检测的程序生成的页面先呈现出来&#xff0c;这里我们想看下面程序的取数操作&#xff0c;所以停留在选择界面 步骤2&#xff1a; 新建一个GUI窗口&#xff0c;输入事务码ST05&#xff0c;点击 Acti…