【机器学习】机器学习的基本分类-强化学习-Deep Q-Network (DQN)

server/2024/12/19 11:29:55/

Deep Q-Network (DQN) 是 Q-Learning 的扩展版本,通过使用深度神经网络来逼近 Q 函数,解决了 Q-Learning 在高维状态空间上的适用性问题。DQN 是深度强化学习的里程碑之一,其突破性地在 Atari 游戏上表现出了超过人类玩家的水平。


DQN 的核心思想

DQN 使用一个神经网络 Q_\theta(s, a) 来逼近状态-动作值函数 Q(s, a)。通过不断地更新网络参数 θ\thetaθ,使其逼近真实的 Q^*(s, a)
其主要改进在于解决了传统 Q-Learning 中 不稳定性发散性 的问题。


DQN 的改进与关键技术

  1. 经验回放(Experience Replay)

    • 将智能体的交互数据存储到一个 回放缓冲区(Replay Buffer)中。
    • 随机采样小批量数据进行训练,以减少样本之间的相关性,提高数据利用率。
  2. 目标网络(Target Network)

    • 引入一个与主网络结构相同但参数固定的 目标网络 Q_{\theta'}(s, a)
    • 每隔一定步数,将主网络的参数 θ\thetaθ 同步到目标网络上,减缓更新的频繁波动。
  3. 奖励剪辑(Reward Clipping)

    • 将奖励值裁剪到 [-1, 1],防止过大值影响梯度更新的稳定性。

DQN 的工作流程

  1. 初始化

    • 初始化主网络 Q_\theta(s, a) 和目标网络 Q_{\theta'}
    • 初始化经验回放缓冲区 D。
  2. 采样交互数据

    • 当前状态 sss 下,按照 \epsilon-贪婪策略选择动作 a:

      • 以 ϵ 的概率随机探索。
      • 以 1−ϵ 的概率选择最大 Q_\theta(s, a)的动作。
    • 执行动作 a,观察即时奖励 R 和下一状态 s′。

    • 将 (s, a, R, s') 存入经验回放缓冲区 D。

  3. 更新网络参数

    • 从 D 中随机采样一个小批量 (s, a, R, s')。
    • 计算目标值(TD 目标):

                                              y = R + \gamma \max_{a'} Q_{\theta'}(s', a')
    • 计算均方误差(MSE)损失:

                                            L(\theta) = \mathbb{E}_{(s, a, R, s') \sim D} \left[ \left( y - Q_\theta(s, a) \right)^2 \right][(y−Qθ​(s,a))2]
    • 使用梯度下降更新主网络参数 θ。
  4. 同步目标网络

    • 每隔固定步数,将主网络的参数 θ 同步到目标网络 θ′。
  5. 迭代训练

    • 重复上述步骤,直到收敛。

伪代码

Initialize Q-network with random weights θ
Initialize target network Q_target with weights θ_target = θ
Initialize replay buffer Dfor episode in range(max_episodes):Initialize state sfor t in range(max_steps_per_episode):# ε-greedy action selectionif random.random() < ε:a = random_action()else:a = argmax(Q(s, a; θ))# Execute action and observe next state and rewards', R, done = environment.step(a)# Store transition in replay bufferD.append((s, a, R, s'))# Sample random minibatch from replay bufferminibatch = random.sample(D, batch_size)# Compute target valuey = R + γ * max(Q_target(s', a'; θ_target)) if not done else R# Compute loss and update Q-networkloss = (y - Q(s, a; θ))^2Perform gradient descent on θ to minimize loss# Update target networkif t % target_update_freq == 0:θ_target ← θif done:break

优缺点

优点
  1. 高效处理高维状态空间:使用神经网络学习 Q(s, a),适用于图像等复杂输入。
  2. 数据利用率高:经验回放缓冲区减少了样本相关性,提高了数据效率。
  3. 稳定性增强:目标网络缓解了更新发散问题。
缺点
  1. 不适用于连续动作空间:DQN 假设动作空间是离散的。
  2. 样本效率低于新方法:如基于策略的算法和 Actor-Critic 方法。
  3. 容易过拟合到训练环境:需要精心设计探索策略。

改进版本

  1. Double DQN

    • 解决 DQN 中 max⁡ 运算导致的 值过高估计 问题。
    • 目标值:

                   y = R + \gamma Q_{\theta'}(s', \arg\max_{a'} Q_\theta(s', a'))
  2. Dueling DQN

    • 将 Q 网络拆分为 状态价值函数 V(s)优势函数 A(s, a)

                              Q(s, a) = V(s) + A(s, a)
  3. Prioritized Experience Replay

    • 通过为经验分配优先级,增加对高 TD 误差样本的采样频率。
  4. Rainbow DQN

    • 集成了多种改进,包括 Double DQN、Dueling DQN、Prioritized Replay、Noisy Networks 等。

应用场景

  1. Atari 游戏

    • 使用原始图像像素作为输入,DQN 在许多 Atari 游戏中实现了超越人类玩家的表现。
  2. 自动驾驶

    • 处理离散决策问题,如车道选择。
  3. 动态资源分配

    • 云计算中的任务分配和调度。
  4. 推荐系统

    • 优化用户交互中的点击率。

http://www.ppmy.cn/server/151427.html

相关文章

mongodb应用心得

基于springboot做mysql业务基础数据分析到mongodb文档库 索引分析 查看当前集合索引&#xff1a;db.collection.getIndexes() explain 方法查看是如何执行的&#xff1a;db.users.find({ name: “John” }).sort({ age: -1 }).explain(“executionStats”) 参数指标&#xff1…

在 Ubuntu 下通过 Docker 部署 Cloudflared Tunnel 服务器

Cloudflared 是 Cloudflare 提供的一个命令行工具&#xff0c;用于创建安全的隧道&#xff0c;连接本地服务器与 Cloudflare 的网络。通过 Cloudflared Tunnel&#xff0c;可以轻松实现安全的远程访问&#xff0c;保护应用程序和服务免受 DDoS 攻击。Docker 则是一个强大的容器…

使用xpath规则进行提取数据并存储

下载lxml !pip install lxmlimport requests headers{"user-agent":"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.6261.95 Safari/537.36" } url"https://movie.douban.com/chart" respon…

复习打卡Linux篇

目录 1. Linux常用操作命令 2. vim编辑器 3. 用户权限 4. Linux系统信息查看 1. Linux常用操作命令 基础操作&#xff1a; 命令说明history查看历史执行命令ls查看指定目录下内容ls -a查看所有文件 包括隐藏文件ls -l ll查看文件详细信息&#xff0c;包括权限类型时间大小…

基于单片机的智能灯光控制系统

摘要 现在的大部分的大学&#xff0c;都是采用了一种“绿色”的教学方式&#xff0c;再加上现在的大学生缺乏环保意识&#xff0c;所以在学校里很多的教室&#xff0c;在白天的时候灯都会打开&#xff0c;这是一种极大的浪费&#xff0c;而且随时都有可能看到&#xff0c;这是…

【自适应】postcss-pxtorem适配Web端页面

在进行页面开发时&#xff0c;自适应设计是一个关键的考虑因素。为了实现这一点&#xff0c;postcss-pxtorem是一个非常有用的工具&#xff0c;它可以将CSS中的px单位转换为rem单位&#xff0c;从而实现基于根元素字体大小的自适应布局。下面介绍一下在项目中如何引入并配置pos…

[OpenGL] Transform feedback 介绍以及使用示例

一、简介 本文介绍了 OpenGL 中 Transform Feedback 方法的基本概念和代码示例。 二、Transform Feedback 介绍 1. Transform Feedback 简介 根据 OpenGL-wiki&#xff0c;Transform Feedback 是捕获由顶点处理步骤&#xff08;vertex shader 和 geometry shader&#xff0…

游戏引擎学习第48天

仓库: https://gitee.com/mrxiao_com/2d_game 回顾 我们正在进行碰撞检测的工作&#xff0c;昨天我们几乎完成了一部分代码。由于一些原因&#xff0c;昨天的直播结束时未能完成所有内容。今天我们将继续进行&#xff0c;首先回顾一下之前的进展。我们需要让角色能够正确地与…