Python神经网络学习(七)--强化学习--使用神经网络

news/2024/11/24 10:49:18/

前言

前面说到了强化学习,但是仅仅是使用了一个表格,这也是属于强化学习的范畴了,毕竟强化学习属于在试错中学习的。

但是现在有一些问题,如果这个表格非常大呢?悬崖徒步仅仅是一个长12宽4,每个位置4个动作的表格而已,如果游戏是英雄联盟,那么多的位置,每个位置那么多的可能动作,画出一个表格简直是不可想象的。

但其实,如果把这个表格看作一个数学函数,他的输入是坐标,输出是一个动作(或者每个动作对应的价值):

那也就是说,只要我们有一个坐标,得到一个动作,中间什么过程是可以不用管的,还记得这篇文章中说过:神经元(函数)+神经元(函数) = 神经网络(人工神经网络),那么,中间这一块也就可以使用神经网络代替,这也就是深度强化学习。

论文(Playing Atari with Deep Reinforcement Learning)地址:https://arxiv.org/abs/1312.5602

 设置环境

注意:今天的环境代码我修改过了,跟上一篇的不一样,所以大家还是要先读一下环境代码。

本次环境代码中添加了对于棋盘大小的设置,修复了一些bug。

# -*- coding: utf-8 -*-
"""
作者:CSDN,chuckiezhu
作者地址:https://blog.csdn.net/qq_38431572
本文可用作学习使用,交流代码时需要附带本出处声明
"""import random
import numpy as npfrom gym import spaces"""
nrows0  1  2  3  4  5  6  7  8  9  10  11  ncols---------------------------------------
0  |  |  |  |  |  |  |  |  |  |  |   |   |---------------------------------------
1  |  |  |  |  |  |  |  |  |  |  |   |   |---------------------------------------
2  |  |  |  |  |  |  |  |  |  |  |   |   |---------------------------------------
3   * |       cliff                  | ^ |*: start pointcliff: cliff^: goal
"""class CustomCliffWalking(object):def __init__(self, stepReward: int=-1, cliffReward: int=-10, goalReward: int=10, col=12, row=4) -> None:self.sr = stepRewardself.cr = cliffRewardself.gr = goalRewardself.col = colself.row = rowself.action_space = spaces.Discrete(4)  # 上下左右self.reward_range = (cliffReward, goalReward)self.pos = np.array([row-1, 0], dtype=np.int8)  # agent 在3,0处出生,掉到悬崖内就会死亡,触发done和cliffRewardself.die_pos = []for c in range(1, self.col-1):self.die_pos.append([self.row-1, c])print("die pos: ", self.die_pos)print("goal pos: ", [[self.row-1, self.col-1]])self.reset()def reset(self, random_reset=False):"""初始化agent的位置random: 是否随机出生, 如果设置random为True, 则出生点会随机产生"""x, y = self.row-1, 0if random_reset:y = random.randint(0, self.col-1)if y == 0:x = random.randint(0, self.row-1)else:  # 除了正常坐标之外,还有一个不正常坐标:(3, 0)x = random.randint(0, self.row-2)# 严格来讲,cliff和goal不算在坐标体系内# agent 在3,0处出生,掉到悬崖内就会死亡,触发done和cliffRewardself.pos = np.array([x, y], dtype=np.int8)# print("reset at:", self.pos)def step(self, action: int) -> list[list, int, bool, bool, dict]:"""执行一个动作action:0: 上1: 下2: 左3: 右"""move = [np.array([-1, 0], dtype=np.int8), # 向上,就是x-1, y不动,np.array([ 1, 0], dtype=np.int8), # 向下,就是x+1, y不动,np.array([0, -1], dtype=np.int8), # 向左,就是y-1, x不动,np.array([0,  1], dtype=np.int8), # 向右,就是y+1, x不动,]new_pos = self.pos + move[action]# 上左不能小于0new_pos[new_pos < 0] = 0  # 超界的处理,比如0, 0 处向上或者向右走,处理完还是0,0# 上右不能超界if new_pos[0] > self.row-1:new_pos[0] = self.row-1  # 超界处理if new_pos[1] > self.col-1:new_pos[1] = self.col-1reward = self.sr  # 每走一步的奖励die = Falsewin = Falseinfo = {"reachGoal": False,"fallCliff": False,}if self.__is_pos_die(new_pos.tolist()):die = Trueinfo["fallCliff"] = Truereward = self.crelif self.__is_pos_win(new_pos.tolist()):win = Trueinfo["reachGoal"] = Truereward = self.grself.pos = new_pos  # 更新坐标return new_pos, reward, die, win, infodef __is_pos_die(self, pos: list[int, int]) -> bool:"""判断自己的这个状态是不是已经结束了"""return pos in self.die_posdef __is_pos_win(self, pos: list[int, int]) -> bool:"""判断自己的这个状态是不是已经结束了"""return pos in [[self.row-1, self.col-1],]

至于讲解这个环境,我觉得这个注释还是比较清楚的,如果有不明白的,请评论留言告知我。

制作网络

首先,我们先把自己代入表格,如果我们站到某个坐标,那么我们应该知道四个方向上的奖励,所以,网络可以有两种方式;

方式一、

网络输入是坐标和方向,输出是对应的奖励。

方式二、

网络输入是坐标,输出是四个方向对应的奖励。

这里我要来一句场外推理:方式一真的很麻烦,并且选择动作的时候,有多少个动作需要经过多少次网络。所以方式二是比较好的选择。


class Qac(nn.Module):def __init__(self, in_shape, out_shape) -> None:super(Qac, self).__init__()self.in_shape = in_shape  # 就是 智能体 现在的坐标self.action_space = out_shape  # 上0下1左2右3self.dense1 = nn.Linear(self.in_shape, self.action_space)# 输出就是每个动作的价值self.lrelu = nn.LeakyReLU()  # 换用tanhself.softmax = nn.Softmax(-1)def forward(self, x) -> torch.Tensor:x = self.dense1(x)return xdef sample_action(self, action_value: torch.Tensor, epsilon: float):"""从产生的动作概率中采样一个动作,利用epsilon贪心"""if random.random() < epsilon:# 随机选择action = random.randint(0, self.action_space-1)action = torch.tensor(action)else:action = torch.argmax(action_value)return actiondef load_model(self, modelpath):"""加载模型"""tmp = torch.load(modelpath)self.load_state_dict(tmp["model"])def save_model(self, modelpath):"""保存模型"""tmp = {"model": self.state_dict(),}torch.save(tmp, modelpath)

细心的人可能发现了,这个网络只有一层,非常简单,好像没有所谓的“特征提取”就直接到输出层了。这里有一个小技巧,就是我手动把坐标转成了onehot向量,可以认为是手动提取了特征。

def num_to_onehot(pos: torch.Tensor) -> torch.Tensor:"""把坐标转成one_hot向量"""n = int((pos[0] * 12 + pos[1]).item())return nn.functional.one_hot(torch.tensor(n), num_classes=48)

如果大家使用两层神经网络,直接输入坐标,中间层是48,然后是一个输出层,也可以, 但是我试了,训练很慢,效果不好。不如这样直接手动编码了。

训练

整个训练的代码我直接贴在这里了:

# -*- coding: utf-8 -*-
"""
利用DQN实现
"""
"""
作者:CSDN,chuckiezhu
作者地址:https://blog.csdn.net/qq_38431572
本文可用作学习使用,交流代码时需要附带本出处声明
"""
import os
import random
import torch
import numpy as np
from torch import nnfrom matplotlib import pyplot as pltfrom cliff_walking_env import CustomCliffWalkingnepisodes = 10000  # total 1w episodes
epsilon = 1.0  # epsilon greedy policy
epsilon_min = 0.05
epsilon_decay = 0.9975gamma = 0.9  # discount factor
lr = 0.001
random_reset = Falseseed = 42normalization = torch.tensor([3, 11], dtype=torch.float)sr = -1
cr = -10
gr = 10class Qac(nn.Module):def __init__(self, in_shape, out_shape) -> None:super(Qac, self).__init__()self.in_shape = in_shape  # 就是智能体现在的坐标self.action_space = out_shape  # 上0下1左2右3self.dense1 = nn.Linear(self.in_shape, self.action_space)# 输出就是每个动作的价值self.lrelu = nn.LeakyReLU()  # 换用tanhself.softmax = nn.Softmax(-1)def forward(self, x) -> torch.Tensor:x = self.dense1(x)return xdef sample_action(self, action_value: torch.Tensor, epsilon: float):"""从产生的动作概率中采样一个动作,利用epsilon贪心"""if random.random() < epsilon:# 随机选择action = random.randint(0, self.action_space-1)action = torch.tensor(action)else:action = torch.argmax(action_value)return actiondef load_model(self, modelpath):"""加载模型"""tmp = torch.load(modelpath)self.load_state_dict(tmp["model"])def save_model(self, modelpath):"""保存模型"""tmp = {"model": self.state_dict(),}torch.save(tmp, modelpath)def num_to_onehot(pos: torch.Tensor) -> torch.Tensor:"""把坐标转成one_hot向量"""n = int((pos[0] * 12 + pos[1]).item())return nn.functional.one_hot(torch.tensor(n), num_classes=48)def main():global epsilonrandom.seed(seed)torch.manual_seed(seed=seed)plt.ion()os.makedirs("./out/ff_DQN/")# cw = gym.make("CliffWalking-v0", render_mode="human")cw = CustomCliffWalking(stepReward=sr, goalReward=gr, cliffReward=cr)# 专程onehot了Q = Qac(in_shape=48, out_shape=cw.action_space.n)optimizer = torch.optim.Adam(Q.parameters(), lr=lr)loss_fn = torch.nn.MSELoss()win_1000 = []  # 记录最近一千场赢的几率total_win = 0for i in range(1, nepisodes+1):cw.reset(random_reset=random_reset)  # 重置环境steps = 0while True:steps += 1state_now = torch.tensor(cw.pos, dtype=torch.float)state_now = num_to_onehot(state_now).unsqueeze_(0).to(torch.float)action_values = Q(state_now)action_values = action_values.squeeze()action_now = Q.sample_action(action_value=action_values, epsilon=epsilon)action_now_value = action_values[action_now]  # 这个是采取这个动作的预测奖励state_next, reward_now, terminated, truncated, info = cw.step(action=action_now.item())   # 执行动作state_next = num_to_onehot(state_next).unsqueeze_(0).to(torch.float)with torch.no_grad():next_values = Q(state_next)next_values = next_values.squeeze()# 得到下一个的动作,(同一个策略下,因为这是onpolicy的sarsaaction_next = Q.sample_action(action_value=action_values, epsilon=epsilon)action_next_value = next_values[action_next]  # 计算下一个动作的预期价值# 计算  instantR + gamma * value_next,这个是实际上这个动作带来的预期收益discounted_reward = reward_now + gamma * action_next_value * (1 - terminated) * (1 - truncated)# 计算误差loss = loss_fn(action_now_value, discounted_reward)optimizer.zero_grad()loss.backward()optimizer.step()if terminated or truncated:if terminated:win_1000.append(0)if truncated:win_1000.append(1)total_win += 1breakepsilon = epsilon * epsilon_decayepsilon = max(epsilon, epsilon_min)  # 衰减学习旅win_1000 = win_1000[-1000:]win_rate = sum(win_1000)/1000.0print("{}/{}, 当前探索率: {}, 是否成功: {}, 千场胜率:{}.".format(i, nepisodes, epsilon, truncated, win_rate), flush=True)if i % 10000 == 0:Q.save_model("./out/ff_DQN/Qac_{}_{}_{}_{}.pth".format(i, gr, cr, win_rate))print("total win: ", total_win)# 收尾测试看看能不能通关path = np.zeros((4, 12), dtype=np.float64)cw.reset(random_reset=False)steps = 0while steps <= 48:  # 走,48步走不到头就不会走到了steps += 1state_now = torch.tensor(cw.pos, dtype=torch.float)state_now = num_to_onehot(state_now).unsqueeze_(0).to(torch.float)action_values = Q(state_now).squeeze()# 贪心算法选择动作action_now = Q.sample_action(action_values, 0)print(cw.pos[0], cw.pos[1], action_now)new_pos, _, die, win, _ = cw.step(action=action_now)if win:print("[+] you win!")breakif die:print("[+] you lose!")breakx = new_pos[0]y = new_pos[1]if x >= 0 and x <= 3 and y >= 0 and y <= 11:path[x, y] = 1.0plt.imshow(path)plt.colorbar()plt.savefig("./out/ff_DQN/path_sarsa_"+str(sr)+"_"+str(gr)+"_"+str(cr)+".png")if __name__ == "__main__":main()

上面的代码我测试没问题,如果不修改直接使用是完全可以的,目录结构是这样的:

那两个文件夹都是自动生成的,不需要手动建立。 

网络结构分析

这是上面代码的网络结构和更新流程。注意:实线代表有梯度,虚线代表无梯度。

每次由环境产生一个状态,先转成一个one_hot向量,作为网络的输入,得到四个动作分别价值多少。然后采样到的动作得到当前的Q(s, a)值,也就是action_value。

另一方面,采样得到的动作送入环境,环境给出下一个状态和立即奖励。下一个状态送入网络(没有梯度的计算),同样得到四个动作的价值。由于代码使用的是SARSA算法,所以需要按照同样的策略采样一个动作,同时得到动作的价值。也就是next_action_value。

这个时候,就可以根据环境的立即奖励reward_now和下一个状态的动作的价值next_action_value得到一个ground truth,而action_value作为网络的预测值,这两个可以用于计算损失。

损失的反向传播就是沿着实现传递到顶。实现网络的更新。

 


http://www.ppmy.cn/news/669564.html

相关文章

Linux操作系统配置代理服务器

PS:本文只是针对Linux操作系统对于代理服务器的配置操作&#xff0c;不涉及广告 1.代理的概念 代理服务器英文全称是Proxy Server&#xff0c;其功能就是代理网络用户去取得网络信息。形象的说&#xff1a;它是网络信息的中转站。在一般情况下&#xff0c;我们使用网络浏览器直…

2022最常用密码公布,你的账户安全吗?

密码管理工具 NordPass 公布了 2022 年最常用密码列表&#xff0c;以及破解密码所需的时间。该研究基于对来自 30 个不同国家 / 地区的 3TB 数据库的分析。研究人员将数据分为不同的垂直领域&#xff0c;使得其能够根据国家和性别进行统计分析。今年的研究主要聚焦于文化如何影…

Leetcode 647. 回文子串

Leetcode 647. 回文子串题目 给你一个字符串 s &#xff0c;请你统计并返回这个字符串中 回文子串 的数目。回文字符串 是正着读和倒过来读一样的字符串。子字符串 是字符串中的由连续字符组成的一个序列。具有不同开始位置或结束位置的子串&#xff0c;即使是由相同的字符组成…

NodeJS安装教程(详细)

系列文章 MySQL安装教程&#xff08;详细&#xff09; 本文链接&#xff1a;https://blog.csdn.net/youcheng_ge/article/details/126037520 MySQL卸载教程&#xff08;详细&#xff09; 本文链接&#xff1a;https://blog.csdn.net/youcheng_ge/article/details/129279265 …

Android 之保护用户隐私-禁止应用截屏或录频

引言 通常情况下&#xff0c;录屏、截图软件都可以在手机的运行过程中进行录屏、截图&#xff0c;但是在某些比较敏感的应用上&#xff0c;出于各种原因&#xff0c;会阻止录屏、截图软件进行运行。一旦录屏、截图软件被阻止运行就无法使用录屏以及截屏的功能。 使用 设置禁止…

3合1锂电便携式风扇IC

产品型号 输入 电压 最大 充电 电流 电机 类型 充电 截止 电压 精度 涓流 充电 截止 电压 功耗 封装 特点 HM5936 4.3 -5.5V 600mA 6V 4.2V 1% 2.9V 30uA SOP-16 带锂电保护,DC-DC升压限流,带充电指示及满电指示 带3LED指示锂电池电量,内置3档可调节风量控制…

微服务、SpringBoot、SpringCloud 三者的区别

&#x1f388; 作者&#xff1a;Linux猿 &#x1f388; 简介&#xff1a;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我&#xff0c;关注我&#xff0c;有问题私聊&#xff01; &…

hive表小文件合并

1. 背景 公司的 hive 表中的数据是通过 flink sql 程序&#xff0c;从 kafka 读取&#xff0c;然后写入 hive 的&#xff0c;为了数据能够被及时可读&#xff0c;我设置了 flink sql 程序的 checkpoint 时间为 1 分钟&#xff0c;因此&#xff0c;在 hive 表对应的 hdfs 上&am…