基于强化学习算法玩CartPole游戏

embedded/2024/10/22 18:43:34/

什么事CartPole游戏

CartPole(也称为倒立摆问题)是一个经典的控制理论和强化学习的基础问题,通常用于测试和验证控制算法的性能。具体来说,它是一个简单的物理模拟问题,其目标是通过在一个平衡杆(倒立摆)上安装在小车(或称为平衡车)上的水平移动,使杆子保持竖直直立的状态。

有两个动作(action):

左移(0)

右移(1)

四个状态(state): 1. 小车在轨道上的位置 2. 杆子与竖直方向的夹角 3. 小车速度 4. 角度变化率

神经网络设计

1、强化学习的训练网络cartpole_train.py

import  gym
import pygame
import time
import random
import torch
from torch.distributions import Categoricalfrom torch import nn, optim
import torch.nn.functional as Fdef compute_policy_loss(n, log_p):r = list()#构造奖励r列表for i in range(n, 0 ,-1):r.append(i *1.0)r = torch.tensor(r)r = (r - r.mean()) / r.std() #进行标准化处理loss = 0#计算损失函数for pi, ri in zip(log_p, r):loss += -pi * rireturn  lossclass CartPolePolicy(nn.Module):def __init__(self):super(CartPolePolicy, self).__init__()self.fc1 = nn.Linear(in_features = 4, out_features = 128)self.fc2 = nn.Linear(128, 2) #输出为神经元个数为2表示,向左和向向右self.drop = nn.Dropout(p=0.6)def forward(self, x):x = self.fc1(x)x = self.drop(x)x = F.relu(x)x = self.fc2(x)#使用softmax决策最终的行动,是向左还是右return F.softmax(x, dim=1)if __name__ == '__main__':env = gym.make("CartPole-v1") #启动环境env.reset(seed= 543)torch.manual_seed(543)policy = CartPolePolicy() #定义模型optimizer = optim.Adam(policy.parameters(), lr = 0.01) #优化器#我们一共最多训练1000个回合#每个回合最多行动10000次#当某一回合的游戏步数超过5000时,就认为完成训练max_episod = 1000 #最大游戏回合数max_action = 10000 #每回合最大行动数max_steps = 5000 #完成训练的步数for episod in range(1, max_episod + 1):# 对于每一轮循环,都要重新启动一次游戏环境state, _ = env.reset()step = 0log_p = list()for step in range(1, max_action + 1):state = torch.from_numpy(state).float().unsqueeze(0)probs = policy(state) #计算神经网络给出的行动概率# 基于网络给出的概率分布,随机选择行动m = Categorical(probs)# 这里并不是直接使用概率较大的行动,而是通过概率分布生成action, 这样可以进一步探索低概率行动action = m.sample()state, _, done, _, _ = env.step(action.item())if done:break #表示跳出该for循环log_p.append(m.log_prob(action)) #保存每次行动对应的概率分布if step > max_steps: #当step大于最大步数时print(f"Done! last episode {episod} Run steps {step}")break #跳出循序,结束训练#每一回合游戏,都会做一次梯度下降算法optimizer.zero_grad()loss = compute_policy_loss(step, log_p)loss.backward()optimizer.step()if episod % 10 ==0:print(f"Episode {episod} Run step {step}")#保存模型torch.save(policy.state_dict(), f"cartpole_policy.pth")

2、验证:cartpole_eval.py

import  gym
import pygame
import torch.nn as nn
import torch.nn.functional as F
import time
import torch
class CartPolePolicy(nn.Module):def __init__(self):super(CartPolePolicy, self).__init__()self.fc1 = nn.Linear(4, 128)self.fc2 = nn.Linear(128, 2)self.drop = nn.Dropout(p=0.6)def forward(self, x):x = self.fc1(x)x = self.drop(x)x = F.relu(x)x = self.fc2(x)return F.softmax(x, dim=1)if __name__ == '__main__':pygame.init() #初始化pygame#使用gym, 创建一个artPole游戏的运行环境,这个环境是提供给人类玩家使用的env = gym.make('CartPole-v1', render_mode = "human")state, _ =env.reset()#使用env.reset重置环境后,会得到CartPole游戏中关键参数statecart_position = state[0] #小车位置cart_speed = state[1] #小车速度pole_angle = state[2] #杆的角度pole_speed = state[3] #杆的尖端速度#加载网络policy = CartPolePolicy()policy.load_state_dict(torch.load("cartpole_policy.pth"))policy.eval()start_time =time.time()max_action =1000 #设置游戏最大执行次数#最多执行1000次方向键,游戏就可以通关结束step = 0fail = Falsefor step in range(1, max_action + 1):#首先使用time.sleep,使游戏暂停0.3s,用于人的反应,觉得自己反应慢可以设置更长时间# time.sleep(0.3)#小车的控制方式,通过神经网络,来决定小车的运动方向#将环境参数state转为张量state = torch.from_numpy(state).float().unsqueeze(0)#输入至网络模型,计算行动概率probsprobs = policy(state)#选取行动概率最大的行动action =torch.argmax(probs, dim = 1).item()state, _, done, _, _ = env.step(action) #done为True,表示杆倒了if done:fail = Truebreakprint(f"step = {step} action = {action} angle = {state[2]:.2f}  position = {state[0]:.2f}")end_time = time.time()game_time = end_time - start_timeif fail:print(f"Game over ,you play {game_time:.2f} seconds, {step} steps.")else:print(f"Congratulations! you play  {game_time:.2f} seconds, {step} steps.")env.close()

视频讲解:

什么是reinforce强化学习算法,基于强化学习玩CartPole游戏_哔哩哔哩_bilibili


http://www.ppmy.cn/embedded/90998.html

相关文章

2024年8月1日(前端服务器的配置以及tomcat环境的配置)

[rootstatic ~]# cd eleme_web/ [rootstatic eleme_web]# cd src/ [rootstatic src]# ls views/ AboutView.vue HomeView.vue [rootstatic src]# vim views/HomeView.vue [rootstatic src]# nohup npm run serve nohup: 忽略输入并把输出追加到"nohup.out" 构建项目…

批量按照原图片名排序修改图片格式为00000001.png(附代码)

💪 专业从事且热爱图像处理,图像处理专栏更新如下👇: 📝《图像去噪》 📝《超分辨率重建》 📝《语义分割》 📝《风格迁移》 📝《目标检测》 📝《暗光增强》 &a…

DataX介绍

DataX是阿里巴巴集团开源的一款高效、易用的数据同步工具,广泛应用于大数据领域的数据迁移、数据备份、数据同步等多种场景。以下是对DataX的详细介绍,包括其特点、架构、使用场景、优缺点以及安装部署等方面。 一、DataX概述 1. 定义与背景 DataX是阿…

java里CMS(Concurrent Mark-Sweep)和G1(Garbage First)垃圾回收器区别

CMS(Concurrent Mark-Sweep)和G1(Garbage First)是两种不同的Java垃圾回收器,它们有着不同的设计目标和实现方式。下面详细解释它们的区别。 CMS垃圾回收器 CMS(Concurrent Mark-Sweep)是JDK 1…

Moretl 单向文件同步工具

使用咨询: 扫码添加QQ 永久免费: Gitee下载最新版本 使用说明: CSDN查看使用说明 功能: 定时(全量采集or增量采集) SCADA,MES等系统采集工控机,办公电脑文件. 优势1: 开箱即用. 解压直接运行.插件集成下载. 优势2: 批管理设备. 配置均在后台配置管理. 优势3: 无人值守 采集端…

php开发的在线客服系统,全开源无加密,支持微信客服对接

介绍: 在网络上找了一圈“客服系统源码”,配置测试了一下,发现所有的“客服系统源码”基本都不能正常的使用。 所以更新了这份新源码,亲测是可以正常使用的,因此顺便也给大家简单分享一下!好东西&#xf…

golang实现切换元素互换的7种方式

方法1:使用临时变量 package mainimport "fmt"func main() {a : []int{1, 2, 3, 4}// 使用临时变量交换tmp : a[0]a[0] a[1]a[1] tmpfmt.Println("a:", a) }方法2:使用多重赋值 package mainimport "fmt"func main() …

Redis持久化

AOF 文件重写参数