PyTorch 中detach的使用:以强化学习中Q-Learning的目标值计算为例

server/2024/12/17 22:13:09/

PyTorch 中 detach 的使用:以强化学习中的目标值计算为例

在强化学习(Reinforcement Learning, RL)中,detach 是一个非常重要的工具。它常用于目标值(Target Value)的计算,确保目标值的梯度不会反向传播到某些特定的神经网络中。例如,在 Q-Learning 等方法中,目标值的计算需要与当前 Q 网络的更新解耦,而 detach 就是在这个场景中广泛使用的工具。

本文将通过一个具体的代码示例,详细介绍 detach 的作用及其在 Q-Learning 中的应用,帮助你理解它是如何工作的。


1. 强化学习中的 Q-Learning 简介

1.1 Q-Learning 的基本思想

Q-Learning 是一种基于值的强化学习算法,其目标是学习一个 Q 函数 ( Q ( s , a ) Q(s, a) Q(s,a) ),表示在状态 ( s s s ) 下选择动作 ( a a a ) 所能获得的期望累积奖励。公式如下:

Q ( s , a ) = r + γ max ⁡ a ′ Q ( s ′ , a ′ ) Q(s, a) = r + \gamma \max_{a'} Q(s', a') Q(s,a)=r+γamaxQ(s,a)

  • ( r r r ):即时奖励(Reward)。
  • ( γ \gamma γ ):折扣因子(Discount Factor),用于衡量未来奖励的重要性。
  • ( max ⁡ a ′ Q ( s ′ , a ′ ) \max_{a'} Q(s', a') maxaQ(s,a) ):下一个状态 ( s ′ s' s ) 中最优动作的 Q 值。

在训练过程中,Q 网络的参数通过以下目标更新:

Loss = ( Q ( s , a ) − Target ( s , a ) ) 2 \text{Loss} = \left( Q(s, a) - \text{Target}(s, a) \right)^2 Loss=(Q(s,a)Target(s,a))2

其中,目标值 ( Target ( s , a ) \text{Target}(s, a) Target(s,a) ) 的计算依赖于目标 Q 网络或冻结的 Q 值,避免其梯度直接影响当前网络的更新。


2. 为什么使用 detach

2.1 防止梯度传播

在 Q-Learning 的目标值计算中,下一状态的 Q 值 ( max ⁡ a ′ Q ( s ′ , a ′ ) \max_{a'} Q(s', a') maxaQ(s,a) ) 不应该参与当前网络参数的更新,因为它属于目标网络或冻结的 Q 值。通过 detach,我们可以从计算图中分离这些值,确保它们的梯度不会影响反向传播。

2.2 提高稳定性

如果目标值直接参与梯度传播,训练可能会出现不稳定甚至发散的情况。通过 detach,可以保证目标值是固定的,从而提高训练的稳定性。


3. 代码示例:Q-Learning 中的目标值计算

以下代码展示了如何使用 detach 分离目标值的梯度计算,确保 Q 网络的更新仅基于当前状态的 Q 值,而不受目标值梯度的影响。

python">import torch# 当前 Q 网络的输出(例如,q_values 表示 Q(s, a))
q_values = torch.tensor([10.0, 20.0, 30.0], requires_grad=True)# 下一状态的 Q 值(例如,next_q_values 表示 max_a' Q(s', a'))
next_q_values = torch.tensor([15.0, 25.0, 35.0], requires_grad=True)# 目标值计算:使用 detach 防止 next_q_values 的梯度传播
gamma = 0.9  # 折扣因子
reward = 1   # 即时奖励
target_q_values = (next_q_values.detach() * gamma) + reward# 损失函数计算
loss = ((q_values - target_q_values) ** 2).mean()# 反向传播
loss.backward()# 打印 q_values 的梯度
print("q_values 的梯度:", q_values.grad)

4. 代码解析

4.1 q_valuesnext_q_values 的定义
python">q_values = torch.tensor([10.0, 20.0, 30.0], requires_grad=True)
next_q_values = torch.tensor([15.0, 25.0, 35.0], requires_grad=True)
  • q_values 表示当前 Q 网络输出的 Q 值。
  • next_q_values 表示下一状态的 Q 值,用于目标值的计算。

两者的 requires_grad=True 表明它们会记录梯度信息。

4.2 detach 的作用
python">target_q_values = (next_q_values.detach() * gamma) + reward
  • 通过 detach(),从计算图中分离出 next_q_values
  • 效果next_q_values 的梯度不会在目标值计算中传播,这保证了目标值是固定的,不影响反向传播。
4.3 损失计算与反向传播
python">loss = ((q_values - target_q_values) ** 2).mean()
loss.backward()
  • loss 是当前 Q 值与目标值之间的均方误差。
  • loss.backward() 计算梯度,此时:
    • q_values 的梯度会被计算并用于更新参数。
    • next_q_values 不参与梯度传播,因为它已被 detach
4.4 输出结果

运行代码后,输出如下:

cq_values 的梯度: tensor([-3.0000, -2.3333, -1.6667])

梯度表示每个 Q 值相对于损失的变化率,用于优化参数。


5. 进一步讨论

5.1 强化学习中的梯度计算

在强化学习中,目标值通常通过固定的目标网络(Target Network)或当前网络的快照计算。detach 可以模拟目标网络的行为,减少计算资源占用,同时避免梯度传播。

5.2 对比 detach 和目标网络

虽然 detach 和目标网络在功能上类似,但目标网络通常需要独立更新参数(如定期同步主网络),而 detach 只是一种简单的梯度分离操作。


6. 总结

本文通过 Q-Learning 的目标值计算,详细介绍了 detach 的作用和用法。在强化学习中,detach 是实现目标值计算的重要工具,可以防止梯度传播,提高训练的稳定性。在实际应用中,detach 的灵活性使其广泛用于各种需要冻结计算图的场景。

通过本文的学习,相信你对 detach 在深度学习中的应用有了更深入的理解,尤其是在强化学习中的重要性。

附录:具体梯度计算过程

以下是完整的梯度计算步骤,以便更清晰地理解代码中 loss.backward() 的作用及 PyTorch 的自动求导机制如何计算梯度。


1. 定义变量和公式

已知的变量
  • ( q _ v a l u e s = [ 10.0 , 20.0 , 30.0 ] q\_values = [10.0, 20.0, 30.0] q_values=[10.0,20.0,30.0] )
  • ( n e x t _ q _ v a l u e s = [ 15.0 , 25.0 , 35.0 ] next\_q\_values = [15.0, 25.0, 35.0] next_q_values=[15.0,25.0,35.0] )
  • 折扣因子 ( γ = 0.9 \gamma = 0.9 γ=0.9 )
  • 即时奖励 ( r e w a r d = 1 reward = 1 reward=1 )
目标值的计算

目标值 ( t a r g e t _ q _ v a l u e s target\_q\_values target_q_values ) 计算公式为:
t a r g e t _ q _ v a l u e s = n e x t _ q _ v a l u e s ⋅ γ + r e w a r d target\_q\_values = next\_q\_values \cdot \gamma + reward target_q_values=next_q_valuesγ+reward

代入具体数值:
t a r g e t _ q _ v a l u e s = [ 15.0 ⋅ 0.9 + 1 , 25.0 ⋅ 0.9 + 1 , 35.0 ⋅ 0.9 + 1 ] = [ 14.5 , 23.5 , 32.5 ] target\_q\_values = [15.0 \cdot 0.9 + 1, 25.0 \cdot 0.9 + 1, 35.0 \cdot 0.9 + 1] = [14.5, 23.5, 32.5] target_q_values=[15.00.9+1,25.00.9+1,35.00.9+1]=[14.5,23.5,32.5]

损失函数

损失函数定义为:
loss = 1 n ∑ i = 1 n ( q _ v a l u e s [ i ] − t a r g e t _ q _ v a l u e s [ i ] ) 2 \text{loss} = \frac{1}{n} \sum_{i=1}^n (q\_values[i] - target\_q\_values[i])^2 loss=n1i=1n(q_values[i]target_q_values[i])2

展开为:
loss = 1 3 ( ( 10.0 − 14.5 ) 2 + ( 20.0 − 23.5 ) 2 + ( 30.0 − 32.5 ) 2 ) \text{loss} = \frac{1}{3} \left( (10.0 - 14.5)^2 + (20.0 - 23.5)^2 + (30.0 - 32.5)^2 \right) loss=31((10.014.5)2+(20.023.5)2+(30.032.5)2)

具体计算:
loss = 1 3 ( 20.25 + 12.25 + 6.25 ) = 1 3 ⋅ 38.75 = 12.9167 \text{loss} = \frac{1}{3} \left( 20.25 + 12.25 + 6.25 \right) = \frac{1}{3} \cdot 38.75 = 12.9167 loss=31(20.25+12.25+6.25)=3138.75=12.9167


2. 梯度计算公式

梯度的定义

根据链式法则,对于 ( q _ v a l u e s [ i ] q\_values[i] q_values[i] ),梯度为:
∂ loss ∂ q _ v a l u e s [ i ] = 2 n ( q _ v a l u e s [ i ] − t a r g e t _ q _ v a l u e s [ i ] ) \frac{\partial \text{loss}}{\partial q\_values[i]} = \frac{2}{n} (q\_values[i] - target\_q\_values[i]) q_values[i]loss=n2(q_values[i]target_q_values[i])

其中:

  • ( n = 3 n = 3 n=3 ) 是样本数。
  • ( q _ v a l u e s [ i ] q\_values[i] q_values[i] ) 是当前的 Q 值。
  • ( t a r g e t _ q _ v a l u e s [ i ] target\_q\_values[i] target_q_values[i] ) 是目标值。

3. 分步计算梯度

第一个元素 ( q _ v a l u e s [ 0 ] q\_values[0] q_values[0] ) 的梯度

∂ loss ∂ q _ v a l u e s [ 0 ] = 2 3 ( 10.0 − 14.5 ) \frac{\partial \text{loss}}{\partial q\_values[0]} = \frac{2}{3} (10.0 - 14.5) q_values[0]loss=32(10.014.5)
计算:
∂ loss ∂ q _ v a l u e s [ 0 ] = 2 3 ⋅ ( − 4.5 ) = − 3.0 \frac{\partial \text{loss}}{\partial q\_values[0]} = \frac{2}{3} \cdot (-4.5) = -3.0 q_values[0]loss=32(4.5)=3.0

第二个元素 ( q_values[1] ) 的梯度

∂ loss ∂ q _ v a l u e s [ 1 ] = 2 3 ( 20.0 − 23.5 ) \frac{\partial \text{loss}}{\partial q\_values[1]} = \frac{2}{3} (20.0 - 23.5) q_values[1]loss=32(20.023.5)
计算:
∂ loss ∂ q _ v a l u e s [ 1 ] = 2 3 ⋅ ( − 3.5 ) = − 2.3333 \frac{\partial \text{loss}}{\partial q\_values[1]} = \frac{2}{3} \cdot (-3.5) = -2.3333 q_values[1]loss=32(3.5)=2.3333

第三个元素 ( q _ v a l u e s [ 2 ] q\_values[2] q_values[2] ) 的梯度

∂ loss ∂ q _ v a l u e s [ 2 ] = 2 3 ( 30.0 − 32.5 ) \frac{\partial \text{loss}}{\partial q\_values[2]} = \frac{2}{3} (30.0 - 32.5) q_values[2]loss=32(30.032.5)
计算:
∂ loss ∂ q _ v a l u e s [ 2 ] = 2 3 ⋅ ( − 2.5 ) = − 1.6667 \frac{\partial \text{loss}}{\partial q\_values[2]} = \frac{2}{3} \cdot (-2.5) = -1.6667 q_values[2]loss=32(2.5)=1.6667


4. 梯度结果

梯度张量为:
q _ v a l u e s . g r a d = tensor ( [ − 3.0 , − 2.3333 , − 1.6667 ] ) q\_values.grad = \text{tensor}([-3.0, -2.3333, -1.6667]) q_values.grad=tensor([3.0,2.3333,1.6667])


5. 代码验证

在代码中运行上述逻辑会得出一致的结果:

python">import torch# 当前 Q 网络的输出(例如,q_values 表示 Q(s, a))
q_values = torch.tensor([10.0, 20.0, 30.0], requires_grad=True)# 下一状态的 Q 值(例如,next_q_values 表示 max_a' Q(s', a'))
next_q_values = torch.tensor([15.0, 25.0, 35.0], requires_grad=True)# 目标值计算:使用 detach 防止 next_q_values 的梯度传播
gamma = 0.9  # 折扣因子
reward = 1   # 即时奖励
target_q_values = (next_q_values.detach() * gamma) + reward# 损失函数计算
loss = ((q_values - target_q_values) ** 2).mean()# 反向传播
loss.backward()# 打印 q_values 的梯度
print("q_values 的梯度:", q_values.grad)

输出:

q_values 的梯度: tensor([-3.0000, -2.3333, -1.6667])

6. 总结

通过上述过程,我们清楚地看到:

  1. 目标值的梯度未传播detach 分离了 next_q_values,避免了它对 loss.backward() 的影响。
  2. 梯度计算q_values 的梯度完全由损失函数定义,并且每个梯度的计算基于链式法则。
  3. 代码验证一致:计算结果与代码输出一致,证明了梯度计算过程的正确性。

这解释了 PyTorch 自动求导机制中 detach 的实际作用及其对梯度传播的控制能力。

附录:如何知道对谁求梯度?

这个例子中怎么知道是对谁求梯度?为什么是对q_value求

在 PyTorch 中,是否对某个变量求梯度取决于以下几个关键因素:


1. 梯度计算的核心原则

  • requires_grad=True:只有设置了 requires_grad=True 的张量才会被记录在计算图中,并在反向传播时计算其梯度。
  • 计算图的构建:PyTorch 会自动追踪涉及到 requires_grad=True 张量的操作,并建立一个动态计算图(Dynamic Computational Graph)。
  • backward() 的起点:反向传播通过 loss.backward() 启动,梯度计算从标量(标量损失值)开始,沿着计算图的路径依次对每个 requires_grad=True 的张量求偏导数。

2. 为什么是对 q_values 求梯度

在这个例子中:

  1. q_values 设置了 requires_grad=True:代码中 q_values = torch.tensor([10.0, 20.0, 30.0], requires_grad=True) 明确指定了我们需要对 q_values 求梯度。这意味着,q_values 是一个需要优化的变量

  2. 计算图的构建

    • target_q_values 是通过 next_q_values.detach() 计算得到的,其中 detach() 会阻止梯度从目标值传播回 next_q_values
    • 损失函数 lossq_valuestarget_q_values 的函数,因此 q_values 是损失的一个直接输入。
  3. 优化目标

    • 强化学习中的 q_values 通常对应于当前策略的预测值(例如 ( Q(s, a) )),我们希望通过梯度下降优化 q_values 的网络参数,以最小化损失。

因此,梯度计算的目标自然是 q_values


3. 为什么不是对 next_q_values 求梯度

next_q_values 的创建方式为:

python">next_q_values = torch.tensor([15.0, 25.0, 35.0], requires_grad=True)

虽然 next_q_valuesrequires_grad=True,但在目标值计算中,我们使用了 next_q_values.detach()

python">target_q_values = (next_q_values.detach() * gamma) + reward
  • detach() 的作用detach() 会从计算图中分离出 next_q_values,使得其在后续计算中不再参与梯度传播。
  • 目标:在强化学习中,next_q_values 通常是通过目标网络计算的值。使用 detach() 是为了确保它不会影响当前 Q 网络的梯度更新。

因此,loss.backward() 时梯度不会传播到 next_q_values


4. 反向传播流程

在反向传播中,loss.backward() 触发如下过程:

  1. 计算损失函数:
    loss = 1 3 ∑ ( q _ v a l u e s [ i ] − t a r g e t _ q _ v a l u e s [ i ] ) 2 \text{loss} = \frac{1}{3} \sum (q\_values[i] - target\_q\_values[i])^2 loss=31(q_values[i]target_q_values[i])2
  2. 按照计算图,从损失开始,沿着计算图依次对每个 requires_grad=True 的张量计算梯度。
  3. 因为 target_q_values 是通过 next_q_values.detach() 计算的,计算图中只有 q_values 会被追踪并计算梯度。

5. 总结:如何知道对谁求梯度

  • 是否追踪计算图:只对 requires_grad=True 的张量计算梯度。
  • 是否分离计算图:如果通过 detach() 分离了计算图,则梯度不会传播到分离的张量。
  • 梯度计算的目标:在反向传播时,PyTorch 会自动沿着计算图从损失出发,对所有需要梯度的张量计算偏导数。

在这个例子中,q_values 是需要优化的变量,因此 loss.backward() 的目的是对 q_values 求梯度,而不是 next_q_values

后记

2024年12月13日11点04分于上海,在GPT4o大模型辅助下完成。


http://www.ppmy.cn/server/151010.html

相关文章

【Excel】单元格分列

目录 分列(新手友好) 1. 选中需要分列的单元格后,选择 【数据】选项卡下的【分列】功能。 2. 按照分列向导提示选择适合的分列方式。 3. 分好就是这个样子 智能分列(进阶) 高级分列 Tips: 新手推荐基…

.NET 技术系列 | 通过CreatePipe函数创建管道

01阅读须知 此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失&#xf…

数据库实验四(SQL 数据库更新操作与完整性约束实践)

一、实验目的 本次实验主要目的包括: 熟练掌握使用 SQL 语句实现更新操作的方法,能够对数据库中的数据进行准确修改。深刻认识完整性约束对数据库的重要性,理解其在维护数据准确性、一致性和可靠性方面的关键作用。精通在 MySQL 中对完整性约…

【Python】paddleocr快速使用及参数详解

文章目录 1. paddleocr快速使用1.1 使用默认模型路径1.2 设定模型路径 2. PaddleOCR其他参数介绍PaddleOCR模型推理参数解释 其它相关推荐: PaddleOCR模型训练及使用详细教程 官方网址:https://github.com/PaddlePaddle/PaddleOCR PaddleOCR是基于Paddle…

Java从入门到工作2 - IDEA

2.1、项目启动 从git获取到项目代码后,用idea打开。 安装依赖完成Marven/JDK等配置检查数据库配置启动相关服务 安装依赖 如果个别依赖从私服下载不了,可以去maven官网下载补充。 如果run时提示程序包xx不存在,在项目目录右键Marven->Re…

从模型到视图:如何用 .NET Core MVC 构建完整 Web 应用

MVC模式自出现以来便成为了 Web 开发的基石,它通过将数据、业务逻辑与用户界面分离,使得应用更加清晰易于维护,然而随着前端技术的飞速发展和框架如 React、Vue、Angular 等的崛起,许多开发者开始倾向于前后端分离的方式&#xff…

elk部署与实战案例

**ELK Stack** 是一个非常强大的日志处理和分析平台,由 **Elasticsearch**、**Logstash** 和 **Kibana** 三个组件组成。它被广泛应用于日志收集、搜索、分析和可视化。ELK 可以处理大量数据,并帮助用户从中提取有价值的信息。以下是一个从部署到实际应用…

AirSim 无人机利用姿态文件获取图片

之前我们得到了随机姿态下无人机获得的不同场景图像,我们输出了无人机随机的姿态信息到poses.csv文件中,现在我们想要复现我们的结果,就要利用我们之前输出的姿态文件来获取图像。 无人机利用姿态文件获取图片的代码如下: impor…