训练强化学习的经验回放策略:experience replay

news/2024/11/16 11:31:02/

经验回放:Experience Replay(训练DQN的一种策略)


优点:可以重复利用离线经验数据;连续的经验具有相关性,经验回放可以在离线经验BUFFER随机抽样,减少相关性;

超参数:Replay Buffer的长度;
∙ Find w by minimizing  L ( w ) = 1 T ∑ t = 1 T δ t 2 2 . ∙ Stochastic gradient descent (SGD): ∙ Randomly sample a transition,  ( s i , a i , r i , s i + 1 ) , from the buffer ∙ Compute TD error,  δ i . ∙ Stochastic gradient: g i = ∂ δ i 2 / 2 ∂ w = δ i ⋅ ∂ Q ( s i , a i ; w ) ∂ w ∙ SGD: w ← w − α ⋅ g i . \begin{aligned} &\bullet\text{ Find w by minimizing }L(\mathbf{w})=\frac{1}{T}\sum_{t=1}^{T}\frac{\delta_{t}^{2}}{2}. \\ &\bullet\text{ Stochastic gradient descent (SGD):} \\ &\bullet\text{ Randomly sample a transition, }(s_i,a_i,r_i,s_{i+1}),\text{from the buffer} \\ &\bullet\text{ Compute TD error, }\delta_i. \\ &\bullet\text{ Stochastic gradient: g}_{i}=\frac{\partial\delta_{i}^{2}/2}{\partial \mathbf{w}}=\delta_{i}\cdot\frac{\partial Q(s_{i},a_{i};\mathbf{w})}{\partial\mathbf{w}} \\ &\bullet\text{ SGD: w}\leftarrow\mathbf{w}-\alpha\cdot\mathbf{g}_i. \end{aligned}  Find w by minimizing L(w)=T1t=1T2δt2. Stochastic gradient descent (SGD): Randomly sample a transition, (si,ai,ri,si+1),from the buffer Compute TD error, δi. Stochastic gradient: gi=wδi2/2=δiwQ(si,ai;w) SGD: wwαgi.


注:实践中通常使用minibatch SGD,每次抽取多个经验,计算小批量随机梯度;
Replay Buffer代码实现如下:

@dataclass
class ReplayBuffer:maxsize: intsize: int = 0state: list = field(default_factory=list)action: list = field(default_factory=list)next_state: list = field(default_factory=list)reward: list = field(default_factory=list)done: list = field(default_factory=list)def push(self, state, action, reward, done, next_state):""":param state: 状态:param action: 动作:param reward: 奖励:param done::param next_state:下一个状态:return:"""if self.size < self.maxsize:self.state.append(state)self.action.append(action)self.reward.append(reward)self.done.append(done)self.next_state.append(next_state)else:position = self.size % self.maxsizeself.state[position] = stateself.action[position] = actionself.reward[position] = rewardself.done[position] = doneself.next_state[position] = next_stateself.size += 1def sample(self, n):total_number = self.size if self.size < self.maxsize else self.maxsizeindices = np.random.randint(total_number, size=n)state = [self.state[i] for i in indices]action = [self.action[i] for i in indices]reward = [self.reward[i] for i in indices]done = [self.done[i] for i in indices]next_state = [self.next_state[i] for i in indices]return state, action, reward, done, next_state

训练时的代码如下:

离线数据放到BUFFER里面:

#动作、状态、奖励、结束标志、下一状态
replay_buffer.push(state, action, reward, done, next_state)

训练时采样然后计算损失

bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
bs = torch.tensor(bs, dtype=torch.float32)
ba = torch.tensor(ba, dtype=torch.long)
br = torch.tensor(br, dtype=torch.float32)
bd = torch.tensor(bd, dtype=torch.float32)
bns = torch.tensor(bns, dtype=torch.float32)loss = agent.compute_loss(bs, ba, br, bd, bns)
loss.backward()
optimizer.step()
optimizer.zero_grad()

http://www.ppmy.cn/news/1010170.html

相关文章

SSE技术和WebSocket技术实现即时通讯

文章目录 一、SSE1.1 什么是SSE1.2 工作原理1.3 特点和适用场景1.4 API用法1.5 代码实现 二、WebSocket2.1 什么是WebSocket2.2 工作原理2.3 特点和适用场景2.4 API用法2.5 代码实现 三、SSE与WebSocket的比较 当涉及到实现实时通信的Web应用程序时&#xff0c;两种常见的技术选…

Tomcat 编程式启动 JMX 监控

通过这篇文章&#xff0c;我们可以了解到&#xff0c;利用 JMX 技术可以方便获取 Tomcat 监控情况。但是我们采用自研的框架而非大家常见的 SpringBoot&#xff0c;于是就不能方便地通过设置配置开启 Tomcat 的 JMX&#xff0c;——尽管我们也是基于 Tomcat 的 Web 容器&#x…

JAVA实现二分法查找算法

现实生活中经常会遇到将具有某个特征的元素选择出来&#xff0c;并找出对应的位置。 现在来一个小测验&#xff0c;在以数组【1&#xff0c;4&#xff0c;8&#xff0c;3&#xff0c;0&#xff0c;7&#xff0c;56】中找到8所在的位置&#xff0c;很明显大家可以通过直观的感受…

Linux 系统编程 开篇/ 文件的打开/创建

从本节开始学习关于Linux系统编程的知识&#xff01; 学习Linux的系统编程有非常多的知识点&#xff0c;在应用层面&#xff0c;很重要的一点就是学习如何“用代码操作文件来实现文件创建&#xff0c;打开&#xff0c;编辑等自动化执行” 那如何自动化实现对文件的创建&#…

学python的心得体会1000字,学python的心得体会2000字

这篇文章主要介绍了学python的心得体会2000字&#xff0c;具有一定借鉴价值&#xff0c;需要的朋友可以参考下。希望大家阅读完这篇文章后大有收获&#xff0c;下面让小编带着大家一起了解一下。 1. 初学者应该从简单的练习开始&#xff0c;先掌握基本的语法和概念&#xff0c;…

【网络基础实战之路】基于三个分公司的内网搭建并连接运营商的实战详解

系列文章传送门&#xff1a; 【网络基础实战之路】设计网络划分的实战详解 【网络基础实战之路】一文弄懂TCP的三次握手与四次断开 【网络基础实战之路】基于MGRE多点协议的实战详解 【网络基础实战之路】基于OSPF协议建立两个MGRE网络的实验详解 PS&#xff1a;本要求基于…

6年资深测试整理,接口测试总结,你不知道的都在这了...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 接口测试 是测试…

如何将 dubbo filter 拦截器原理运用到日志拦截器中?

业务背景 我们希望可以在使用日志拦截器时&#xff0c;定义属于自己的拦截器方法。 实现的方式有很多种&#xff0c;我们分别来看一下。 拓展阅读 java 注解结合 spring aop 实现自动输出日志 java 注解结合 spring aop 实现日志traceId唯一标识 java 注解结合 spring ao…