【强化学习】玩转Atari-Pong游戏

news/2025/1/11 22:36:26/
如果您感觉项目还不错,请您点个fork支持一下,谢谢qwq

在这里插入图片描述

玩转Atari-Pong游戏

  • Atari: 雅达利,最初是一家游戏公司,旗下有超过200款游戏,不过已经破产。在强化学习中,Atari游戏是经典的实验环境之一,因此,本项目旨在学习使用强化学习算法玩Atari游戏。
  • Pong: 1972年,雅达利(Atari)创办人布什内尔及达布尼推出首款街机Pong,最初仅生产12部,以简单点线接口仿真打乒乓球的游戏,奠定街机始祖地位。该游戏的简略版英文描述为:

You control the right paddle, you compete against the left paddle controlled by the computer. You each try to keep deflecting the ball away from your goal and into your opponent’s goal.

翻译成中文就是:

你控制右边的球拍,你与电脑控制的左边的球拍竞争。你们各自努力使球不断偏离自己的目标,进入对手的目标。

游戏示意图:

在这里插入图片描述

从该动态图可以看出,不经训练的右侧球拍完全打不过左侧球拍的,因此我们的目标就是训练右侧球拍使其战胜左侧球拍。

  • Pong环境的状态、动作与奖励:

    • 状态:Pong环境提供的状态默认是Box(210, 160, 3),也就是3通道的彩色图
    • 动作:Pong-v0和Pong-V4版本返回的动作都是Discrete(6),也就是离散的6个动作。网上有介绍:Pong 环境介绍,提到其实6个动作中有用的只有3个,可以参考该介绍,加深理解。
    • 奖励:奖励有三种状态:-1,0,1,分别表示右侧未接到球;中间过程;左侧未接到球。
  • 训练结果展示:


    在这里插入图片描述

我们同时提供了动态图Pong-v4_trained.gif,因为该动态图超过10MB,无法展示,可自行下载观看。

1.Atari环境的安装

在运行man.ipynb之前,请先运行help.ipynb生成我们的依赖环境!!!

目前Ai studio平台并没有内嵌Atari环境,需要我们自行安装,为避免反复安装,我们将安装过程写到help.ipynb。可运行我们的help.ipynb进行持久化安装。主要的安装命令如下所示:

  1. ! pip install atari_py==0.2.6 -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  2. ! pip install ale-py -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  3. ! pip install pyglet -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  4. ! pip install autorom -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  5. ! pip install AutoROM.accept-rom-license -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  6. !rar x Roms.rar
  7. !python -m atari_py.import_roms ROMS

其中需要注意:第4、5条安装命令可能无法一次成功,多运行几次即可;第6条命令一个项目仅运行一次即可。

2.导入我们的依赖包

注意要先将我们自行安装的Atari环境加入到系统中,即

sys.path.append(‘/home/aistudio/external-libraries’)

import sys 
sys.path.append('/home/aistudio/external-libraries')import gym
import numpy as np
import time
import matplotlib.pyplot as plt
import paddle
import os
from collections import deque,Counter
from visualdl import LogWriter
import copy
from collections import Counter
from matplotlib import animation
from PIL import Image

3.环境测试

检测我们是否可以成功加载环境,并查看我们的状态空间和动作空间

env = gym.make('Pong-v4')
print(env.observation_space)
print(env.action_space)
Box(210, 160, 3)
Discrete(6)

4.状态的预处理

在这里我们首先定义了状态的预处理函数preprocess,该函数说明如下:

  • 输入:状态,Pong环境给出的不加任何处理的环境状态,Box(210, 160, 3)
  • 处理:处理过程可以看我们下边的过程图片。
    • 裁剪:将实际没有用的部分去除,主要是Pong环境返回的图像的上边和下边的部分
    • 下采样:在保留特征的前提下进行像素点的缩减
    • 擦除背景,在我们下采样后,环境的背景其实是有两种(109,144),这个也需要多观察才能看出,可以参考我们给出的示例图。
    • 转为灰度图:非0即1,我们仅保留左右球拍和球,减少不必要因素的干扰
    • 打平:将图像打平,进而只使用线性层进行特征学习

4.1 preprocess函数

def preprocess(image):""" 预处理 210x160x3 uint8 frame into 6400 (80x80) 1维 float vector """image = image[35:195]  # 裁剪image = image[::2, ::2, 0]  # 下采样,缩放2倍image[image == 144] = 0  # 擦除背景 (background type 1)image[image == 109] = 0  # 擦除背景 image[image != 0] = 1  # 转为灰度图,除了黑色外其他都是白色return image.astype(np.float).ravel() #打平,(6400,)

4.2 对preprocess函数进行可视化说明,展示中间过程

def show_image(status):status1=status[35:195] #裁剪有效区域status2 = status1[::2, ::2, 0] #下采样,缩减# 观察我们的像素点构成def see_color(status):allcolor=[]for i in range(80):allcolor.extend(status[i])dict_color=Counter(allcolor)print("像素点构成: ",dict_color)see_color(status2)# 观察好像素点后,擦除背景def togray(image_in):image=image_in.copy()image[image == 144] = 0  # 擦除背景 (background type 1)image[image == 109] = 0  # 擦除背景image[image != 0] = 1  # 转为灰度图,除了黑色外其他都是白色return imagestatus3=togray(status2)# 可视化我们的操作中间图def show_status(list_status):fig = plt.figure(figsize=(8, 32), dpi=200)plt.subplots_adjust(left=None, bottom=None, right=None, top=None,wspace=0.3, hspace=0)for i in range(len(list_status)):plt.subplot(1,len(list_status),i+1)plt.imshow(list_status[i],cmap=plt.cm.binary)plt.show()show_status([status,status1,status2,status3])

4.3 背景为109的preprocess展示

status = env.reset() #原始图
show_image(status)
像素点构成:  Counter({109: 6382, 101: 16, 53: 2})/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingif isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))

在这里插入图片描述

4.4 背景为144的preprocess展示

for i in range(200):action=env.action_space.sample()status,reward,done,info=env.step(action)show_image(status)
像素点构成:  Counter({144: 6366, 213: 16, 92: 16, 236: 2})

在这里插入图片描述

5.模型的定义,简单的全连接层

class Model(paddle.nn.Layer):""" 使用全连接网络.参数:obs_dim (int): 观测空间的维度.act_dim (int): 动作空间的维度."""def __init__(self, obs_dim, act_dim):super(Model, self).__init__()hid1_size = 256hid2_size = 64self.fc1 = paddle.nn.Linear(obs_dim, hid1_size)self.fc2 = paddle.nn.Linear(hid1_size, hid2_size)self.fc3 = paddle.nn.Linear(hid2_size, act_dim)def forward(self, obs): h1 = paddle.nn.functional.relu(self.fc1(obs))h2 = paddle.nn.functional.relu(self.fc2(h1))prob = paddle.nn.functional.softmax(self.fc3(h2), axis=-1)return prob

6.策略梯度算法

强化学习的经典算法之一,可以参考我们之前的项目【强化学习】REINFORCE算法

在这里我们仅定义预测更新两个函数。

# 梯度下降算法
class PolicyGradient():def __init__(self, model, lr):self.model = modelself.optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=self.model.parameters())def predict(self, obs):prob = self.model(obs)return probdef learn(self, obs, action, reward):prob = self.model(obs)#print("prob: ",prob)log_prob = paddle.distribution.Categorical(prob).log_prob(action)loss = paddle.mean(-1 * log_prob * reward)self.optimizer.clear_grad()loss.backward()self.optimizer.step()return loss

7.策略梯度智能体

  • 我们默认从文件中加载参数进行训练,因为PG算法+Pong环境的训练需要大量的时间,一次直接训练完成很耗时;当然我们支持从0开始训练
  • sample: 在训练时调用的函数,带探索
  • predict:在预测(测试)时调用的函数,不带探索
  • learn:更新函数
  • save和load:保存参数和加载参数。注意:这里我们保存了优化器的参数,但是在加载是并未加载上优化器的参数,有报错,未进行修复,但是不加载优化器参数几乎不影响我们的训练的。(这里我其实不太明白到底需不需加载优化器参数,还望大佬不吝赐教,拜谢)
class Agent():def __init__(self, algorithm):self.alg=algorithmif os.path.exists("./savemodel"):print("开始从文件加载参数....")try:self.load()print("从文件加载参数结束....")except:print("从文件加载参数失败,从0开始训练....")def sample(self, obs):""" 根据观测值 obs 采样(带探索)一个动作"""obs = paddle.to_tensor(obs, dtype='float32')prob = self.alg.predict(obs)#print("prob:",prob)prob = prob.numpy()act = np.random.choice(len(prob), 1, p=prob)[0]  # 根据动作概率选取动作return actdef predict(self, obs):""" 根据观测值 obs 选择最优动作"""obs = paddle.to_tensor(obs, dtype='float32')prob = self.alg.predict(obs)act = prob.argmax().numpy()[0]  # 根据动作概率选择概率最高的动作return actdef learn(self, obs, act, reward):""" 根据训练数据更新一次模型参数"""act = np.expand_dims(act, axis=-1)reward = np.expand_dims(reward, axis=-1)obs = paddle.to_tensor(obs, dtype='float32')act = paddle.to_tensor(act, dtype='int32')reward = paddle.to_tensor(reward, dtype='float32')#print("gggggggggggggg",obs.shape,act.shape,reward.shape)loss = self.alg.learn(obs, act, reward)return loss.numpy()[0]def save(self):paddle.save(self.alg.model.state_dict(),'./savemodel/PG-Pong_net.pdparams')paddle.save(self.alg.optimizer.state_dict(), "./savemodel/opt.pdopt")def load(self):# 加载网络参数model_state_dict=paddle.load('./savemodel/PG-Pong_net.pdparams')self.alg.model.set_state_dict(model_state_dict)# # 加载优化器参数# optimizer_state_dict=paddle.load("./savemodel/opt.pdopt")# self.alg.optimizer.set_state_dict(optimizer_state_dict)

8. 训练与测试

8.1 定义训练函数

# 训练一个episode
def run_train_episode(agent, env):obs_list, action_list, reward_list = [], [], []obs = env.reset()while True:obs = preprocess(obs)  # from shape (210, 160, 3) to (6400,)obs_list.append(obs)action = agent.sample(obs)action_list.append(action)obs, reward, done, info = env.step(action)# if reward!=0:#     print("reward: ",action)reward_list.append(reward)if done:breakreturn obs_list, action_list, reward_list

8.2 定义预测函数

# 评估 agent, 跑 5 个episode,总reward求平均
def run_evaluate_episodes(agent, env, render=False):eval_reward = []for i in range(5):obs = env.reset()episode_reward = 0while True:obs = preprocess(obs)  # from shape (210, 160, 3) to (6400,)action = agent.predict(obs)obs, reward, isOver, _ = env.step(action)episode_reward += rewardif render:env.render()if isOver:breakeval_reward.append(episode_reward)return np.mean(eval_reward)

8.3 定义奖励处理函数

进行奖励衰减操作,衰减因子gamma默认为0.99

def calc_reward_to_go(reward_list, gamma=0.99):"""calculate discounted reward"""reward_arr = np.array(reward_list)for i in range(len(reward_arr) - 2, -1, -1):# G_t = r_t + γ·r_t+1 + ... = r_t + γ·G_t+1reward_arr[i] += gamma * reward_arr[i + 1]# normalize episode rewardsreward_arr -= np.mean(reward_arr)reward_arr /= np.std(reward_arr)return reward_arr

8.4 训练与预测的主函数

便于演示,我们仅进行100次的继续训练,读者可自行增加次数以获得更好的训练效果

def main():env = gym.make('Pong-v4')obs_dim = 80 * 80act_dim = env.action_space.nprint('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))# 根据parl框架构建agentLEARNING_RATE = 5e-4model = Model(obs_dim=obs_dim, act_dim=act_dim)alg = PolicyGradient(model, lr=LEARNING_RATE)agent = Agent(alg)twriter=LogWriter('./logs/PG_Pong')for i in range(100): # default 3000obs_list, action_list, reward_list = run_train_episode(agent, env)twriter.add_scalar('reward',sum(reward_list),i)if i % 50 == 0:print("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))batch_obs = np.array(obs_list)batch_action = np.array(action_list)batch_reward = calc_reward_to_go(reward_list)#print("ggggggggggggg",batch_obs.shape)agent.learn(batch_obs, batch_action, batch_reward)last_test_total_reward=0if (i + 1) % 100 == 0:# render=True 查看显示效果total_reward = run_evaluate_episodes(agent, env, render=False)print('Test reward: {}'.format(total_reward))# save the parametersif last_test_total_reward<total_reward:last_test_total_reward=total_rewardagent.save()# 运行整个程序
main()
obs_dim 6400, act_dim 6W1022 22:01:06.998914   174 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1022 22:01:07.003042   174 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.开始从文件加载参数....
从文件加载参数结束....
Episode 0, Reward Sum 14.0.
Episode 50, Reward Sum 8.0.
Test reward: 12.0

9.使用训练好的网络进行测试并生成动图

9.1 gif动图生成函数

def save_frames_as_gif(frames, filename):#Mess with this to change frame sizeplt.figure(figsize=(frames[0].shape[1]/100, frames[0].shape[0]/100), dpi=300)patch = plt.imshow(frames[0])plt.axis('off')def animate(i):patch.set_data(frames[i])anim = animation.FuncAnimation(plt.gcf(), animate, frames = len(frames), interval=50)anim.save(filename, writer='pillow', fps=60)

9.2 从文件加载模型参数

model=Model(6400,6)
model_state_dict=paddle.load("./savemodel/PG-Pong_net.pdparams")
model.set_state_dict(model_state_dict)

9.4 使用训练好的模型进行测试并保存过程为动图

env=gym.make('Pong-v4')state=env.reset()
frames = []
done=0
i=0
reward_list=[]
while not done:frames.append(env.render(mode="rgb_array"))obs = preprocess(state)obs = paddle.to_tensor(obs, dtype='float32')prob = model(obs)action = prob.argmax().numpy()[0]next_state,reward,done,_=env.step(action)if reward!=0:reward_list.append(reward)print(i,"   ",reward,done)state=next_statei+=1reward_counter=Counter(reward_list)
print(reward_counter)
print("你的得分为:",reward_counter[1.0],'对手得分为:',reward_counter[-1.0])
if reward_counter[1.0]>reward_counter[-1.0]:print("恭喜您赢了!!!")
else:print("惜败,惜败,训练一下智能体网络再来挑战吧QWQ")save_frames_as_gif(frames, filename="Pong-v4_trained.gif")env.close()
199     1.0 False
732     1.0 False
937     1.0 False
1547     1.0 False
1676     1.0 False
1877     1.0 False
2165     1.0 False
2451     1.0 False
2575     1.0 False
2705     1.0 False
2995     1.0 False
3125     1.0 False
3331     1.0 False
3454     1.0 False
3584     1.0 False
3793     1.0 False
4885     1.0 False
5096     1.0 False
5698     1.0 False
5992     1.0 False
6202     1.0 True
Counter({1.0: 21})
你的得分为: 21 对手得分为: 0
恭喜您赢了!!!

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dsXN1i1W-1667103194205)(main_files/main_37_1.png)]

10. 总结

本项目参考自飞桨PARL,鼓励大家给点点stars
在这里插入图片描述

本项目目前通过5000+回合的训练,我们的智能体已经学会通过快速抖动法取得游戏的胜利了,但是大概率还不能完全碾压,后续有时间会继续训练或采取更加高效的算法进行改进。然后,这是我的第一个Atari游戏项目,之前都在在经典的控制游戏下进行实验,环境的转变使得学习的难度也上升,训练时间也在增加,学到的东西也在增加,挺好的…还请大佬多多指教,小黑还有很多路要走,嘿嘿!

之前的强化学习项目有:

  • DQN+CartPole-v0
  • A2C+CartPole-v0
  • DDPG+Pendulum-v0
  • TD3+Pendulum-v0
  • REINFORCE+CartPole-v0
  • PPO+CartPole-v0
  • SAC+Pendulum-v0

欢迎大家来交流学习!!!
此文章为搬运
原项目链接


http://www.ppmy.cn/news/967274.html

相关文章

基于Egret的OPPO小游戏接入

参考文档&#xff1a;OPPO小游戏打包官方文档 前提 安装了 node 环境&#xff0c;建议安装 8.x 稳定版本 [node官网&#xff1a;https://nodejs.org/en/]开发 Cocos Creator 游戏&#xff0c;需要升级到2.0.6及以上版本开发 Laya 游戏&#xff0c;需要 laya air 使用1.7.19或1…

逆向新手,经典扫雷游戏确定雷区地址的5个方法

逆向新手&#xff0c;经典扫雷游戏确定雷区地址的几个方法 前言 逆向新手&#xff0c;经典扫雷游戏确定雷区地址的几个方法 一、通过相关的数据区来确定 结合CE实现 首先从数据结构去考虑&#xff0c;思考‘’雷‘’与‘’非雷‘’数据&#xff08;还包括不同等级的窗口的高…

C++学习笔记1

Hello World程序的组成部分 可以分为两部分&#xff1a; &#xff08;1&#xff09;以#开头的是预处理器编译指令 &#xff08;2&#xff09;int main() 开头的是程序的主体 预处理编译指令#include 定义&#xff1a;预处理器是一个在编译前运行的工具 #include 是让预处理器获…

计算机网络军训口号,关于物联网的军训口号

每个学校都要一个霸气外露的军训口号&#xff0c;关于物联网的军训口号有哪些。以下是小编分享给大家的关于关于物联网的军训口号&#xff0c;希望大家喜欢! 关于物联网的军训口号精选 1. 继承人民军队光荣传统和优良作风 为民族复兴刻苦学习 2. 发扬集体主义和革命英雄主义精神…

计算机网络军训口号,军训各营连口号

十多天的军训生活已经接近尾声了&#xff0c;同学们都在训练场上紧张地为阅兵进行着操练&#xff0c;他们步伐整齐&#xff0c;口号嘹亮&#xff0c;气宇轩昂。 小编走访军训旅一团训练场&#xff0c;收集了来自一团各个营连的口号。这些口号风格各异&#xff0c;却都展现出同学…

计算机系统军训口号,计算机专业军训口号_3篇(共4页)1000字.docx

计算机专业军训口号  计算机专业军训口号精选  1.认真学习刻苦训练;级班能文能武!  2.好好学习天天向上;团结进取争创一流!  3.团结进取奋力拼搏;齐心协力共铸辉煌!  4.超越自我挑战极限;团结互助勇创佳绩!  5.天骄十九不懈追求;勇往直前争创一流!  6.军中骄子校…

计算机学院军训横幅,2020大学军训横幅标语句子精选100句

认真学习&#xff0c;刻苦训练&#xff0c;文武兼备&#xff0c;百炼成钢。今天小编就给大家整理了大学军训横幅标语句子&#xff0c;希望对大家的工作和学习有所帮助&#xff0c;欢迎阅读! 【1】大学军训横幅标语句子 1、团结一致&#xff0c;勇争第一。 2、流血流汗不流泪&am…

计算机网络军训口号,计算机军训口号

【计算机专业军训口号】 1. 争当训练标兵&#xff0c;共创先进连队。 2. 不经历风雨 &#xff0c;怎么见彩虹。 3. 同心同德求实创新齐育桃李芳天下&#xff0c;自律自强奋发进取共添德艺馨未来。 4. 团结一心&#xff0c;努力拼搏。 5. 明德尚行&#xff0c;矢志报国&#xff…