用神经网络自动玩游戏

news/2024/11/2 12:05:27/

神经网络玩游戏
CartPole是OpenAI gym中的一个游戏测试,车上顶着一个自由摆动的杆子,实现杆子的平衡,杆子每次倒向一端车就开始移动让杆子保持动态直立的状态.
 


游戏地址:https://gymnasium.farama.org/env ... _control/cart_pole/

一、搭建游戏运行环境
 

pip install swig
pip install gymnasium[box2d]



CartPole 环境内置在 gym 中,直接安装 gym 即可。其环境 id 是CartPole-v0 。Gym是一个研究和开发强化学习相关算法的仿真平台。简单来说OpenAI Gym提供了许多问题和环境(或游戏)的接口,而用户无需过多了解游戏的内部实现,通过简单地调用就可以用来测试和仿真。

pip install gym



启动游戏完整代码:

import gym# 创建 CartPole 环境
env = gym.make('CartPole-v1', render_mode="human")  # 使用新版本时,需要指定 render_mode# 重置环境,准备开始游戏
observation, info = env.reset()for _ in range(1000):# 随机选择一个动作(0或1)action = env.action_space.sample()# 应用动作到环境中,返回新的状态、奖励、完成标志和其他信息observation, reward, done, truncated, info = env.step(action)# 如果游戏结束或被截断,重置环境if done or truncated:observation, info = env.reset()# 关闭游戏
env.close()


可以看到游戏在自动运行。

二、手动控制
增加键盘操作:
 

# 获取键盘按键状态
keys = pygame.key.get_pressed()# 默认动作为随机生成 0 或 1,除非检测到按键输入
if not keys[K_LEFT] and not keys[K_RIGHT]:action = random.choice([0, 1])  # 随机选择 0 或 1# 根据按键改变动作
if keys[K_LEFT]:print("Left arrow pressed")action = 0  # 向左移动
elif keys[K_RIGHT]:print("Right arrow pressed")action = 1  # 向右移动




完整代码如下:

import gym
import pygame
from pygame.locals import K_LEFT, K_RIGHT, QUIT
import time
import random  # 用于生成随机数# 初始化 pygame
pygame.init()# 设置帧率
FPS = 60
clock = pygame.time.Clock()# 创建 CartPole 环境,使用 Gym 自带的渲染模式
env = gym.make('CartPole-v1', render_mode="human")# 重置环境,准备开始游戏
observation, info = env.reset()# 初始化游戏的运行标志
running = True# 操作次数
action_count = 0# 记录游戏开始的时间
start_time = time.time()# 游戏循环
while running:# 处理事件队列,检测关闭窗口操作for event in pygame.event.get():if event.type == QUIT:  # 退出事件running = Falsetime.sleep(0.3)# 获取键盘按键状态keys = pygame.key.get_pressed()# 默认动作为随机生成 0 或 1,除非检测到按键输入if not keys[K_LEFT] and not keys[K_RIGHT]:action = random.choice([0, 1])  # 随机选择 0 或 1# 根据按键改变动作if keys[K_LEFT]:print("Left arrow pressed")action = 0  # 向左移动elif keys[K_RIGHT]:print("Right arrow pressed")action = 1  # 向右移动# 执行动作并更新环境状态observation, reward, done, truncated, info = env.step(action)# 每执行一次动作,增加操作次数action_count += 1# 渲染环境,减少渲染频率(每隔 10 帧渲染一次)if action_count % 10 == 0:env.render()# 检查游戏是否结束if done or truncated:# 记录游戏结束的时间end_time = time.time()# 计算游戏持续的时间(秒)elapsed_time = end_time - start_time# 打印游戏信息print(f"游戏结束!总共操作了 {action_count} 次,持续时间为 {elapsed_time:.2f} 秒")# 重置环境observation, info = env.reset()# 重置计数器和时间action_count = 0start_time = time.time()# 控制帧率,确保游戏速度合适clock.tick(FPS)  # 设置帧率为 60 帧/秒# 关闭游戏环境和 pygame
env.close()
pygame.quit()

最好成绩成功控制了57次。
 




三、用电脑控制

#电脑自嗨
if observation[2] <= 0:action = 0
else:action = 1


电脑居然控制了59步。

四、用神经网络控制

1.训练神经网络
 

pip install torch
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque# 设置设备(如果有GPU则使用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 创建 CartPole 环境
env = gym.make('CartPole-v1')# 超参数
GAMMA = 0.99  # 折扣因子
EPSILON_START = 1.0  # 初始探索率
EPSILON_END = 0.01  # 最低探索率
EPSILON_DECAY = 0.995  # 探索率衰减
LEARNING_RATE = 0.001  # 学习率
MEMORY_SIZE = 10000  # 经验回放的容量
BATCH_SIZE = 64  # 批量大小
TARGET_UPDATE = 10  # 每隔多少步更新目标网络
TRAIN_TARGET_REWARD = 5000  # 训练停止目标:游戏持续达到 5000 步# DQN 网络结构
class DQN(nn.Module):def __init__(self, state_size, action_size):super(DQN, self).__init__()self.fc1 = nn.Linear(state_size, 24)self.fc2 = nn.Linear(24, 24)self.fc3 = nn.Linear(24, action_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))return self.fc3(x)# 经验回放池
class ReplayMemory:def __init__(self, capacity):self.memory = deque(maxlen=capacity)def push(self, experience):self.memory.append(experience)def sample(self, batch_size):return random.sample(self.memory, batch_size)def __len__(self):return len(self.memory)# 选择动作的epsilon-greedy策略
def select_action(state, policy_net, epsilon, n_actions):if random.random() < epsilon:return random.randrange(n_actions)  # 随机探索else:with torch.no_grad():state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)return policy_net(state).argmax(dim=1).item()  # 利用网络选择最优动作# 更新目标网络
def update_target_network(policy_net, target_net):target_net.load_state_dict(policy_net.state_dict())# 训练 DQN 网络
def optimize_model(policy_net, target_net, memory, optimizer):if len(memory) < BATCH_SIZE:return# 从经验回放池中采样experiences = memory.sample(BATCH_SIZE)# 将经验解包为不同部分states, actions, rewards, next_states, dones = zip(*experiences)states = torch.tensor(states, dtype=torch.float32).to(device)actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1).to(device)rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1).to(device)next_states = torch.tensor(next_states, dtype=torch.float32).to(device)dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1).to(device)# 当前 Q 值q_values = policy_net(states).gather(1, actions)# 下一个状态的最大 Q 值(目标网络)with torch.no_grad():max_next_q_values = target_net(next_states).max(1)[0].unsqueeze(1)target_q_values = rewards + (GAMMA * max_next_q_values * (1 - dones))# 计算损失loss = nn.MSELoss()(q_values, target_q_values)# 反向传播优化optimizer.zero_grad()loss.backward()optimizer.step()# 初始化 DQN 网络和目标网络
n_actions = env.action_space.n
state_size = env.observation_space.shape[0]policy_net = DQN(state_size, n_actions).to(device)
target_net = DQN(state_size, n_actions).to(device)
update_target_network(policy_net, target_net)# 优化器
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)# 经验回放
memory = ReplayMemory(MEMORY_SIZE)# 训练主循环
epsilon = EPSILON_START
num_episodes = 1000for episode in range(num_episodes):state, _ = env.reset()done = Falsetotal_steps = 0total_reward = 0while not done:# 选择动作action = select_action(state, policy_net, epsilon, n_actions)# 执行动作next_state, reward, done, truncated, _ = env.step(action)total_steps += 1total_reward += reward# 如果游戏结束,则给负奖励reward = reward if not done else -10# 将经验存入回放池memory.push((state, action, reward, next_state, done))# 更新状态state = next_state# 训练模型optimize_model(policy_net, target_net, memory, optimizer)# 停止训练条件:游戏累计达到5000步,保存模型if total_steps >= TRAIN_TARGET_REWARD:print(f"Training completed after {episode+1} episodes, reaching {total_steps} steps.")torch.save(policy_net.state_dict(), 'dqn_cartpole.pth')break# 每隔一定步数更新目标网络if episode % TARGET_UPDATE == 0:update_target_network(policy_net, target_net)# epsilon衰减epsilon = max(EPSILON_END, EPSILON_DECAY * epsilon)print(f"Episode {episode + 1}/{num_episodes},total_steps:{total_steps}, Total Reward: {total_reward}")# 检查是否已经达到目标if total_steps >= TRAIN_TARGET_REWARD:breakenv.close()



保存模型的代码:

# 训练结束后保存模型
if total_steps >= TRAIN_TARGET_REWARD:print(f"Training completed after {episode+1} episodes, reaching {total_steps} steps.")torch.save(policy_net.state_dict(), 'dqn_cartpole.pth')break



模型保存成功,大小6kb

2.使用神经网络控制游戏:
调用神经网络决策关键代码:

    state_tensor = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(device)action = policy_net(state_tensor).argmax(dim=1).item()



使用模型并调用神经网络控制完整代码如下:
 

import gym
import pygame
import torch
import torch.nn as nn
from pygame.locals import K_LEFT, K_RIGHT, QUIT
import time
import random  # 用于生成随机数# 设置设备(如果有GPU则使用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 初始化 pygame
pygame.init()# 设置帧率
FPS = 60
clock = pygame.time.Clock()# 创建 CartPole 环境,使用 Gym 自带的渲染模式
env = gym.make('CartPole-v1', render_mode="human")# 重置环境,准备开始游戏
observation, info = env.reset()# 初始化游戏的运行标志
running = True# 操作次数
action_count = 0# 记录游戏开始的时间
start_time = time.time()# DQN 网络结构
class DQN(nn.Module):def __init__(self, state_size, action_size):super(DQN, self).__init__()self.fc1 = nn.Linear(state_size, 24)self.fc2 = nn.Linear(24, 24)self.fc3 = nn.Linear(24, action_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))return self.fc3(x)# 初始化 DQN 网络和目标网络
n_actions = env.action_space.n
state_size = env.observation_space.shape[0]# 加载模型并使用神经网络控制游戏
policy_net = DQN(state_size, n_actions).to(device)# 加载训练好的模型权重
MODEL_PATH = 'dqn_cartpole.pth'
policy_net.load_state_dict(torch.load(MODEL_PATH))# 设置为评估模式
policy_net.eval()# 游戏循环
while running:# 处理事件队列,检测关闭窗口操作for event in pygame.event.get():if event.type == QUIT:  # 退出事件running = Falsetime.sleep(0.3)# 获取键盘按键状态keys = pygame.key.get_pressed()# 默认动作为随机生成 0 或 1,除非检测到按键输入if not keys[K_LEFT] and not keys[K_RIGHT]:action = random.choice([0, 1])  # 随机选择 0 或 1# 根据按键改变动作if keys[K_LEFT]:print("Left arrow pressed")action = 0  # 向左移动elif keys[K_RIGHT]:print("Right arrow pressed")action = 1  # 向右移动state_tensor = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(device)action = policy_net(state_tensor).argmax(dim=1).item()# 执行动作并更新环境状态observation, reward, done, truncated, info = env.step(action)# 每执行一次动作,增加操作次数action_count += 1# 渲染环境,减少渲染频率(每隔 10 帧渲染一次)if action_count % 10 == 0:env.render()# 检查游戏是否结束if done or truncated:# 记录游戏结束的时间end_time = time.time()# 计算游戏持续的时间(秒)elapsed_time = end_time - start_time# 打印游戏信息print(f"游戏结束!总共操作了 {action_count} 次,持续时间为 {elapsed_time:.2f} 秒")# 重置环境observation, info = env.reset()# 重置计数器和时间action_count = 0start_time = time.time()# 控制帧率,确保游戏速度合适clock.tick(FPS)  # 设置帧率为 60 帧/秒# 关闭游戏环境和 pygame
env.close()
pygame.quit()

每次都能顺利通关
 


更详细:

神经网络自动玩游戏
https://www.jinshuangshi.com/forum.php?mod=viewthread&tid=353
(出处: 金双石科技)
 


http://www.ppmy.cn/news/1543859.html

相关文章

.net Core 使用Panda.DynamicWebApi动态构造路由

我们以前是通过创建controller来创建API&#xff0c;通过controller来显示的生成路由&#xff0c;这里我们讲解下如何不通过controller&#xff0c;构造API路由 安装 Panda.DynamicWebApi 1.2.2 1.2.2 Swashbuckle.AspNetCore 6.2.3 6.2.3添加ServiceAction…

平安养老险蚌埠中支开展敬老月金融知识宣传活动

近日&#xff0c;平安养老保险股份有限公司&#xff08;以下简称“平安养老险”&#xff09;蚌埠中心支公司走进蚌埠老年人户外活动中心——珠园&#xff0c;开展2024年“敬老月”金融知识宣传活动。本次活动以“坚持以老年人为中心&#xff0c;构建老年友好型社会”为主题&…

2024年双十一母婴专场 五款惊喜爆款产品大公开

在年度最盛大的购物狂欢节——双十一即将来临之际&#xff0c;我们特别为期待已久的各位父母与未来父母准备了一份惊喜。2024年的双十一母婴专场将带来五款经过精挑细选的爆款产品&#xff0c;每一款都是对宝宝成长路上的贴心呵护。从智能婴儿监护器到专业护眼台灯&#xff0c;…

【Gorm】传统sql的增删查改,通过go去操作sql

MySQL中的建库&#xff0c;建表&#xff0c;删库&#xff0c;删表&#xff0c;添加记录&#xff0c;查询&#xff0c;删除记录&#xff0c;更新记录这些命令是一定要回的&#xff0c;就算我们脱离 orm 这些&#xff0c;也能直接连接上数据库进行操作。 一、数据库的操作 # 查…

Php实现钉钉OA一级审批,二级审批

Php实现钉钉OA一级审批&#xff0c;二级审批 一级审批 public function oaPush($user_id,$person,$data){//测试数据&#xff0c;上线需要删除$user_id 154502333155;//发起人$person [154502665555];//审批人$len count($person);$result null;if($len>0){$approve_con…

开源趣味艺术画板Paint Board

什么是 Paint Board &#xff1f; Paint Board 是简洁易用的 Web 端创意画板。它集成了多种创意画笔和绘画功能&#xff0c;支持形状绘制、橡皮擦、自定义画板等操作&#xff0c;并可以将作品保存为图片。 软件功能&#xff1a; 不过非常可惜&#xff0c;老苏最期待的数据同步还…

躺平成长-下一个更新的数据(躺平成长数据显示核心)

旭日图&#xff08;Sunburst Chart&#xff09;是一种用于展示具有层次结构数据的可视化图表。 它起源于饼图和环形图&#xff0c;并随着数据可视化需求的发展而演变。 旭日图通过将层次结构数据以由内向外的同心圆环形式展示&#xff0c;使数据的层次关系更加清晰直观。 以下…

golang的RSA加密解密

参考&#xff1a;https://blog.csdn.net/lady_killer9/article/details/118026802 1.加密解密工具类PasswordUtil.go package utilimport ("crypto/rand""crypto/rsa""crypto/x509""encoding/pem""fmt""log"&qu…