【机器学习】机器学习的基本分类-强化学习-REINFORCE 算法

ops/2024/12/31 0:24:30/

REINFORCE 算法

REINFORCE 是一种基于策略梯度的强化学习算法,直接通过采样环境中的轨迹来优化策略。它是策略梯度方法的基础实现,具有简单直观的优点。


核心思想

  1. 目标函数

    • 最大化策略的期望回报:

                              ​​​​​​​         J(\theta) = \mathbb{E}_{\pi_\theta} \left[ \sum_{t=0}^T \gamma^t R_t \right]
    • 通过优化策略参数 θ,使累积回报 J(θ) 最大化。
  2. 策略梯度定理

    • 策略梯度为:

              ​​​​​​​        ​​​​​​​         \nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a_t | s_t) \cdot G_t \right]
    • 其中 G_t = \sum_{k=t}^T \gamma^{k-t} R_k​ 是从时间步 t 开始的累积回报。
  3. 梯度估计

    • 使用采样方法估计梯度:

              ​​​​​​​         \nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^N \sum_{t=0}^{T_i} \nabla_\theta \log \pi_\theta(a_t^i | s_t^i) \cdot G_t^i
    • 其中 N 是采样的轨迹数量。

算法流程

  1. 初始化

    • 随机初始化策略参数 θ。
  2. 采样轨迹

    • 使用当前策略 \pi_\theta(a|s)与环境交互,生成 N 条轨迹。
  3. 计算回报

    • 对每条轨迹计算累积回报 G_t
  4. 计算梯度

    • 根据策略梯度定理计算梯度 \nabla_\theta J(\theta)
  5. 更新策略参数

    • 使用梯度上升更新策略参数:

              ​​​​​​​        ​​​​​​​                \theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)
  6. 迭代

    • 重复上述步骤,直至策略收敛。

伪代码

Initialize policy network with random weights θ
for episode in range(max_episodes):Generate a trajectory using πθCompute returns G_t for each step in the trajectoryfor each step in the trajectory:Compute policy gradient:∇θ J(θ) = ∇θ log πθ(a_t | s_t) * G_tUpdate policy network parameters:θ ← θ + α * ∇θ J(θ)


关键特性

  1. 无基线版本

    • 直接使用累积回报 G_t 作为更新目标。
    • 高方差:每条轨迹的回报差异可能很大,导致梯度估计的不稳定性。
  2. 基线改进

    • 减少方差的常用方法是在梯度中引入基线 b(s),更新规则变为:

               ​​​​​​​        \nabla_\theta J(\theta) = \nabla_\theta \log \pi_\theta(a_t | s_t) \cdot (G_t - b(s_t))
    • 其中 b(st)b(s_t)b(st​) 通常是状态值函数 V(s_t) 的估计值。

优缺点

优点
  1. 实现简单

    • 通过采样轨迹即可直接优化策略。
  2. 适用于复杂策略

    • 可以学习高维连续动作或多样化策略。
  3. 灵活性

    • 可结合多种改进(如基线、Actor-Critic 方法)。
缺点
  1. 高方差

    • 回报 G_t 的方差较高,导致策略收敛较慢。
  2. 数据利用效率低

    • 每次更新仅使用一次采样的轨迹。
  3. 不稳定

    • 需要仔细调整学习率和其他超参数以确保收敛。

应用场景

  1. 游戏 AI

    • 用于优化游戏智能体的策略。
  2. 机器人控制

    • 优化机械臂或移动机器人在连续动作空间中的行为。
  3. 推荐系统

    • 动态优化用户推荐的长期回报。
  4. 金融交易

    • 在复杂的交易环境中设计交易策略。

改进方法

  1. 基线函数

    • 减少策略梯度的方差,提高更新的稳定性。
  2. Actor-Critic

    • 结合值函数的 Actor-Critic 方法,通过同时学习值函数和策略,进一步提高效率。
  3. Trust Region Policy Optimization (TRPO)

    • 限制策略更新幅度,确保每次更新的稳定性。
  4. Proximal Policy Optimization (PPO)

    • 通过裁剪策略更新的范围,兼顾效率和稳定性。

代码示例(简化版)

以下是一个 Python 示例,使用 NumPy 实现 REINFORCE:

import numpy as np# 环境接口
class Environment:def reset(self):# 返回初始状态passdef step(self, action):# 执行动作,返回 (下一状态, 奖励, 是否终止)pass# 策略网络 (简单线性模型)
class PolicyNetwork:def __init__(self, state_dim, action_dim):self.weights = np.random.randn(state_dim, action_dim)def predict(self, state):logits = np.dot(state, self.weights)return np.exp(logits) / np.sum(np.exp(logits))  # Softmaxdef update(self, grads, learning_rate):self.weights += learning_rate * grads# REINFORCE 算法
def reinforce(env, policy, episodes, learning_rate):for episode in range(episodes):state = env.reset()trajectory = []# 采样轨迹while True:probs = policy.predict(state)action = np.random.choice(len(probs), p=probs)next_state, reward, done = env.step(action)trajectory.append((state, action, reward))state = next_stateif done:break# 计算回报G = 0grads = np.zeros_like(policy.weights)for t, (state, action, reward) in enumerate(reversed(trajectory)):G = reward + 0.99 * Ggrad = np.zeros_like(policy.weights)grad[:, action] = stategrads += grad * (G - np.mean([x[2] for x in trajectory]))  # 使用基线# 更新策略policy.update(grads, learning_rate)


http://www.ppmy.cn/ops/143825.html

相关文章

Python从0到100(七十八):神经网络--从0开始搭建全连接网络和CNN网络

前言: 零基础学Python:Python从0到100最新最全教程。 想做这件事情很久了,这次我更新了自己所写过的所有博客,汇集成了Python从0到100,共一百节课,帮助大家一个月时间里从零基础到学习Python基础语法、Pyth…

STM32单片机芯片与内部33 ADC 单通道连续DMA

目录 一、ADC DMA配置——标准库 1、ADC配置 2、DMA配置 二、ADC DMA配置——HAL库 1、ADC配置 2、DMA配置 三、用户侧 1、DMA开关 (1)、标准库 (2)、HAL库 2、DMA乒乓 (1)、标准库 &#xff…

FFmpeg 4.3 音视频-多路H265监控录放C++开发二十一.3,RTCP协议, RTCP协议概述,RTCP协议详情

官方文档参考:RFC 3550 - RTP: A Transport Protocol for Real-Time Applications

uniapp v-tabs修改了几项功能,根据自己需求自己改

根据自己的需求都可以改 这里写自定义目录标题 1.数组中的名字过长,导致滑动异常2.change 事件拿不到当前点击的数据,通过index在原数组中查找得到所需要的id 各种字段麻烦3.添加指定下标下新加红点显示样式 1.数组中的名字过长,导致滑动异常…

STM32内部flash分区

STM32的内部Flash根据型号和容量的不同,分区方式可能有所差异,但通常都包含以下几个主要部分: 主存储器:这是内部Flash的主要部分,用于存放程序代码和数据常量。在STM32F4系列中,主存储器被划分为多个扇区…

Java面试被问到GC相关问题如何回答?

前言 众所周知,Java在运行时将内存划分为五个主要部分:程序计数器、虚拟机栈、本地方法栈、堆以及方法区。值得注意的是,程序计数器、虚拟机栈和本地方法栈这三个区域的内存管理相对简单,它们的生命周期与线程同步,即…

防止私接小路由器

电脑获取到IP地址不是DHCP服务器的IP地址段,导致整个公司网络瘫痪,这些故障现象通常80%原因是私接小路由器导致的,以下防止私接小路由器措施。 一、交换机配置DHCP Sooping DHCP snooping是一种DHCP安全特性,用于防止非法设备获…

计算机网络 八股青春版

什么是HTTP?HTTP和HTTPS的区别 HTTP HTTP是超文本运输协议,是一种无状态(每次请求都是独立的)的应用层协议。用于在客户端和服务器之间传输超文本数据(如HTML文件)。默认端口是80数据以明文形式传输&#…