强化学习玩flappy_bird

news/2024/9/23 9:26:15/

强化学习玩flappy_bird(代码解析)

游戏地址:https://flappybird.io/

该游戏的规则是:

  • 点击屏幕则小鸟立即获得向上速度。

  • 不点击屏幕则小鸟受重力加速度影响逐渐掉落。

  • 小鸟碰到地面会死亡,碰到水管会死亡。(碰到天花板不会死亡)

  • 小鸟通过水管会得分。

    img

    具体的网络结构如图所示,网络架构是拿到游戏状态(每个样本维度是 80 * 80 * 4),然后卷积(输出维度 20 * 20 * 32)、池化(输出 10 * 10 * 32)、卷积(输出 5 * 5 * 64)、卷积(输出 5 * 5 * 64)、reshape(1600)、全连接层(512)、输出层(2)

一、flappy_bird_utils.py

python">"""
游戏素材加载
"""
import pygame
import sys
import osassets_dir = os.path.dirname(__file__)def load():# 小鸟挥动翅膀的3个造型PLAYER_PATH = (assets_dir + '/assets/sprites/redbird-upflap.png',assets_dir + '/assets/sprites/redbird-midflap.png',assets_dir + '/assets/sprites/redbird-downflap.png')# 游戏背景图,纯黑色是为了训练降低干扰BACKGROUND_PATH = assets_dir + '/assets/sprites/background-black.png'# 水管图片PIPE_PATH = assets_dir + '/assets/sprites/pipe-green.png'IMAGES, SOUNDS, HITMASKS = {}, {}, {}#初始化了三个空字典:IMAGES用于存储加载的图片资源,SOUNDS用于存储加载的声音资源,HITMASKS用于存储碰撞掩码(用于检测游戏中的碰撞)# 加载数字0~9的图片,类型是Surface图像#使用convert_alpha()方法将图片转换为带有透明度的格式IMAGES['numbers'] = (pygame.image.load(assets_dir + '/assets/sprites/0.png').convert_alpha(),pygame.image.load(assets_dir + '/assets/sprites/1.png').convert_alpha(),pygame.image.load(assets_dir + '/assets/sprites/2.png').convert_alpha(),pygame.image.load(assets_dir + '/assets/sprites/3.png').convert_alpha(),    # convert/conver_alpha是为了将图片转成绘制用的像素格式,提高绘制效率pygame.image.load(assets_dir + '/assets/sprites/4.png').convert_alpha(),pygame.image.load(assets_dir + '/assets/sprites/5.png').convert_alpha(),pygame.image.load(assets_dir + '/assets/sprites/6.png').convert_alpha(),pygame.image.load(assets_dir + '/assets/sprites/7.png').convert_alpha(),pygame.image.load(assets_dir + '/assets/sprites/8.png').convert_alpha(),pygame.image.load(assets_dir + '/assets/sprites/9.png').convert_alpha())# 地面图片IMAGES['base'] = pygame.image.load(assets_dir + '/assets/sprites/base.png').convert_alpha()#根据操作系统类型,设置声音文件的扩展名。Windows系统使用.wav,其他系统使用.oggif 'win' in sys.platform:soundExt = '.wav'else:soundExt = '.ogg'# 各种Sound对象#加载各种游戏音效,并将它们存储在SOUNDS字典中SOUNDS['die']    = pygame.mixer.Sound(assets_dir + '/assets/audio/die' + soundExt)SOUNDS['hit']    = pygame.mixer.Sound(assets_dir + '/assets/audio/hit' + soundExt)SOUNDS['point']  = pygame.mixer.Sound(assets_dir + '/assets/audio/point' + soundExt)SOUNDS['swoosh'] = pygame.mixer.Sound(assets_dir + '/assets/audio/swoosh' + soundExt)SOUNDS['wing']   = pygame.mixer.Sound(assets_dir + '/assets/audio/wing' + soundExt)# 加载背景图片IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()# 加载小鸟的3个姿态IMAGES['player'] = (pygame.image.load(PLAYER_PATH[0]).convert_alpha(),pygame.image.load(PLAYER_PATH[1]).convert_alpha(),pygame.image.load(PLAYER_PATH[2]).convert_alpha(),)# 加载水管图片,并使用rotate()方法将其旋转180度以创建上方的水管图片,然后将这两个图片存储在IMAGES字典中IMAGES['pipe'] = (pygame.transform.rotate(pygame.image.load(PIPE_PATH).convert_alpha(), 180),pygame.image.load(PIPE_PATH).convert_alpha(),)# 计算水管图片的bool掩码#为水管图片生成碰撞掩码,并将它们存储在HITMASKS字典中。HITMASKS['pipe'] = (getHitmask(IMAGES['pipe'][0]),getHitmask(IMAGES['pipe'][1]),)# 生成小鸟图片的bool掩码HITMASKS['player'] = (getHitmask(IMAGES['player'][0]),getHitmask(IMAGES['player'][1]),getHitmask(IMAGES['player'][2]),)return IMAGES, SOUNDS, HITMASKS# 生成图片的bool掩码矩阵,true表示对应像素位置不是透明的部分
def getHitmask(image):"""returns a hitmask using an image's alpha."""mask = []for x in range(image.get_width()):#遍历所有的像素点mask.append([])for y in range(image.get_height()):mask[x].append(bool(image.get_at((x,y))[3]))    # 像素点是RGBA,例如:(83, 56, 70, 255),最后是透明度(0是透明,255是不透明)#对于图像中的每一个像素点,使用 image.get_at((x,y)) 获取该点的颜色值。颜色值通常以 RGBA(红色、绿色、蓝色、透明度)格式存储,其中 A 代表 Alpha 通道,即透明度。image.get_at((x,y))[3] 就是获取该像素点的 Alpha 值。return mask

这里面需要解释的碰撞掩码是什么?

碰撞掩码(Collision Mask)是一种在计算机图形学和游戏开发中用于检测物体间碰撞的技术。它通常由一个布尔矩阵表示,其中每个像素点的值表示该点是否是物体的一部分。在处理碰撞检测时,通过比较两个物体的碰撞掩码可以判断它们是否重叠,从而确定是否发生了碰撞。

以下是碰撞掩码的一些关键点:

  1. 透明度判断:在许多游戏中,碰撞掩码是通过检查图像的透明度(Alpha通道)来生成的。如果图像的某个像素点是不透明的(例如,Alpha值为255),那么在碰撞掩码中对应的位置会被标记为True或实体部分;如果是透明的(Alpha值为0),则被标记为False或非实体部分。

  2. 简化碰撞检测:使用碰撞掩码可以避免直接对图像的每个像素点进行碰撞检测,这样可以显著提高碰撞检测的效率,尤其是在处理复杂图形或大规模场景时。

  3. 灵活性:碰撞掩码可以根据需要设计成不同的形状和大小,从而实现精确的碰撞检测。例如,一个角色的碰撞掩码可以是其轮廓的形状,而不仅仅是一个矩形或正方形。

  4. 性能优化:在游戏开发中,碰撞检测通常是一个计算密集型的过程。通过使用碰撞掩码,可以减少不必要的像素点比较,从而提高游戏性能。

  5. 应用场景:碰撞掩码不仅用于检测角色与障碍物之间的碰撞,还可以用于检测子弹与目标的碰撞、角色间的交互等。

getHitmask函数就是用来生成碰撞掩码的。它通过遍历图像的每个像素点,并检查其透明度来创建一个布尔矩阵。这个矩阵随后可以用于游戏中的碰撞检测逻辑,以判断小鸟是否与水管或其他物体发生了碰撞。

二、wrapped_flappy_bird.py

python">import numpy as np
import sys
import random
import pygame
from . import flappy_bird_utils
import pygame.surfarray as surfarray#用于将pygame的Surface对象转换为NumPy数组
from pygame.locals import *
from itertools import cycle#用于创建一个可循环的对象# 屏幕宽*高
FPS = 30
SCREENWIDTH  = 288
SCREENHEIGHT = 512# 初始化游戏,创建一个时钟对象来控制帧率,设置游戏窗口的尺寸和标题
pygame.init()
FPSCLOCK = pygame.time.Clock()  # FPS限速器
SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT))   # 宽*高
pygame.display.set_caption('Flappy Bird')   # 标题# 加载素材
IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load()PIPEGAPSIZE = 100 # 上下水管之间的距离是固定的100像素
BASEY = SCREENHEIGHT * 0.79 # 地面图片的y坐标'''
地面图片在游戏窗口中的垂直位置。SCREENHEIGHT是游戏窗口的高度,
乘以0.79后得到一个值,这个值就是地面图片距离窗口顶部的像素距离。
因此,BASEY变量代表了地面图片在Y轴(垂直轴)上的位置,
它被设置在屏幕高度的79%的位置,这样地面图片会显示在屏幕的下半部分。
'''# 小鸟图片的宽*高
PLAYER_WIDTH = IMAGES['player'][0].get_width()
PLAYER_HEIGHT = IMAGES['player'][0].get_height()
# 水管图片的宽*高
PIPE_WIDTH = IMAGES['pipe'][0].get_width()
PIPE_HEIGHT = IMAGES['pipe'][0].get_height()# 背景图片的宽
BACKGROUND_WIDTH = IMAGES['background'].get_width()# 创建一个循环对象,小鸟图片动画播放顺序
PLAYER_INDEX_GEN = cycle([0, 1, 2, 1])'''
0 表示第一张图片(翅膀上挥)
1 表示第二张图片(翅膀中挥)
2 表示第三张图片(翅膀下挥)
序列最后再次包含1,以实现翅膀的自然循环
'''# Flappy bird游戏类
class GameState:def __init__(self):self.score = 0#初始化玩家的得分为 0self.playerIndex = 0#初始化玩家小鸟的当前动画索引为 0,这将决定小鸟显示哪一张动画图片self.loopIter = 0#初始化一个循环计数器,可能用于跟踪动画或游戏循环的次数# 玩家初始坐标self.playerx = int(SCREENWIDTH * 0.2)#设置玩家小鸟的初始 x 坐标,位于屏幕宽度的 20% 位置self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2)#计算并设置玩家小鸟的初始 y 坐标,使得小鸟位于屏幕垂直居中的位置# 地面图片需要跑马灯效果,它比屏幕宽一点,每帧向左移动,当要耗尽时重新回到右边,如此往复self.basex = 0         # 地面图片的x坐标self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH  # 地面图片比屏幕宽度长多少像素,就是它可以移动的距离newPipe1 = getRandomPipe()  # 生成一对上下管子newPipe2 = getRandomPipe()  # 再生成一对上下管子# 上面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离self.upperPipes = [{'x': SCREENWIDTH, 'y': newPipe1[0]['y']},{'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']},]# 下面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离self.lowerPipes = [{'x': SCREENWIDTH, 'y': newPipe1[1]['y']},{'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']},]# 水管的水平移动速度,每次x-4实现向左移动self.pipeVelX = -4# 小鸟Y方向速度self.playerVelY    =  0# 小鸟Y方向重力加速度,每帧作用域playerVelY,令其Y速度向下加大self.playerAccY    =   1# 点击后,小鸟Y方向速度重置为-9,也就是开始向上移动self.playerFlapAcc =  -9# 小鸟Y方向速度限制self.playerMaxVelY =  10   # Y向下最大速度10# 执行一次操作,返回操作后的画面、本次操作的奖励(活着+0.1,死了-1,飞过水管+1)、游戏是否结束def frame_step(self, input_actions):# 给pygame对积累的事件做一下默认处理pygame.event.pump()# 活着就奖励0.1分reward = 0.01# 是否死了terminal = False# 必须传有效的action,[1,0]表示不点击,[0,1]表示点击,全传0是不对的if sum(input_actions) != 1:#检查 input_actions 确保只有一个动作被执行raise ValueError('Multiple input actions!')# 每3帧换一次小鸟造型图片,loopIter统计经过了多少帧if (self.loopIter + 1) % 3 == 0:self.playerIndex = next(PLAYER_INDEX_GEN)self.loopIter += 1# 让地面向左移动,游戏开始的时候地面x=0,逐步减小xif self.basex + self.pipeVelX <= -self.baseShift:self.basex = 0else: # 图片即将滚动耗尽,重置x坐标self.basex += self.pipeVelX# 点击了屏幕if input_actions[1] == 1:self.playerVelY = self.playerFlapAcc # 将小鸟y方向速度重置为-9,也就是向上移动#SOUNDS['wing'].play()   # 播放扇翅膀的声音elif self.playerVelY < self.playerMaxVelY:  # 没点击屏幕并且没达到最大掉落速度,继续施加重力加速度self.playerVelY += self.playerAccY# 将速度施加到小鸟的y坐标上self.playery += self.playerVelYif self.playery < 0:    # 撞到上边缘不算死self.playery = 0 # 限制它别飞出去elif self.playery + PLAYER_HEIGHT >= BASEY: # 小鸟碰到地面self.playery = BASEY - PLAYER_HEIGHT # 限制它别穿地# 让上下水管都向左移动一次for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):uPipe['x'] += self.pipeVelXlPipe['x'] += self.pipeVelX# 判断小鸟是否穿过了一排水管,因为上下水管x一样,只需要用上排水管判断playerMidPos = self.playerx + PLAYER_WIDTH / 2  # 小鸟中心的x坐标(这个是固定值,小鸟实际不会动,是水管在动)for pipe in self.upperPipes:    # 检查与上排水管的关系pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 # 水管中心的x坐标if pipeMidPos <= playerMidPos < pipeMidPos + abs(self.pipeVelX): # 小鸟x坐标刚刚飞过了水管x中心(4是水管的移动速度)self.score += 1 # 游戏得分+1#SOUNDS['point'].play()reward = 100  # 产生强化学习的动作奖励10分# 最左侧水管马上离开屏幕,生成新水管if 0 < self.upperPipes[0]['x'] < 5:newPipe = getRandomPipe()self.upperPipes.append(newPipe[0])self.lowerPipes.append(newPipe[1])# 最左侧水管彻底离开屏幕,删除它的上下2根水管if self.upperPipes[0]['x'] < -PIPE_WIDTH:self.upperPipes.pop(0)self.lowerPipes.pop(0)# 检查小鸟是否碰到水管isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 'index': self.playerIndex}, self.upperPipes, self.lowerPipes)if isCrash:  # 死掉了#SOUNDS['hit'].play()#SOUNDS['die'].play()reward = -10 # 负向激励分terminal = True # 本次操作导致游戏结束了##### 进入重绘 ######## 贴背景图SCREEN.blit(IMAGES['background'], (0,0))# 画水管for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y']))SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y']))# 画地面SCREEN.blit(IMAGES['base'], (self.basex, BASEY))# 画得分(训练时候别打开,造成干扰了)#showScore(self.score)# 画小鸟SCREEN.blit(IMAGES['player'][self.playerIndex], (self.playerx, self.playery))# 重绘pygame.display.update()# 留存游戏画面(截图是列优先存储的,需要转行行优先存储)# https://stackoverflow.com/questions/34673424/how-to-get-numpy-array-of-rgb-colors-from-pygame-surfaceimage_data = pygame.surfarray.array3d(pygame.display.get_surface()).swapaxes(0,1)# 死亡则重置游戏状态if terminal:self.__init__()# 控制FPSFPSCLOCK.tick(FPS)return image_data, reward, terminal# 生成一对水管,放到屏幕外面
def getRandomPipe():gapY = random.randint(70, 140)#生成一个介于 70 到 140 之间的随机整数,并将其赋值给变量 gapY。这个随机数决定了水管之间缝隙的上边缘的 y 坐标# 注:每一对水管的缝隙高度都是一样的PIPEGAPSIZE,gayY决定的是缝隙的上边缘y坐标pipeX = SCREENWIDTH + 10    # 水管出现在屏幕右侧之外return [{'x': pipeX, 'y': gapY - PIPE_HEIGHT},  # 计算上面水管图片的y坐标,就是缝隙上边缘y减去水管本身高度{'x': pipeX, 'y': gapY + PIPEGAPSIZE},  # 计算下面水管图片的y坐标,就是缝隙上边缘y加上缝隙本身高度]# 检查小鸟是否碰到水管或者地面(天花板不算)
def checkCrash(player, upperPipes, lowerPipes):pi = player['index']    # 小鸟的第几张图片# 图片的宽*高player['w'] = IMAGES['player'][pi].get_width()player['h'] = IMAGES['player'][pi].get_height()# 小鸟碰到了地面if player['y'] + player['h'] >= BASEY - 1:return Trueelse: # 小鸟与水管进行碰撞检测# 小鸟图片的矩形区域playerRect = pygame.Rect(player['x'], player['y'], player['w'], player['h'])# 每一对水管for uPipe, lPipe in zip(upperPipes, lowerPipes):# 上面水管的矩形uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)# 下面水管的矩形lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)# 小鸟图片的非透明像素掩码pHitMask = HITMASKS['player'][pi]# 上水管的非透明像素掩码uHitmask = HITMASKS['pipe'][0]# 下水管的非透明像素掩码lHitmask = HITMASKS['pipe'][1]# 检测小鸟与上面水管的碰撞uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask)# 检测小鸟与下面水管的碰撞lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask)if uCollide or lCollide:return Truereturn False# 2个矩形区域的碰撞检测
def pixelCollision(rect1, rect2, hitmask1, hitmask2):'''rect1 和 rect2 是参与碰撞检测的两个矩形区域,通常是游戏中对象的位置和大小hitmask1 和 hitmask2 是与这两个矩形关联的碰撞掩码,它们是布尔数组,表示相应对象的哪些部分是实体(非透明)'''# 计算两个矩形的交集,即它们重叠的区域。如果没有重叠(即两个矩形没有碰撞),则 clip 方法返回一个宽度或高度为 0 的矩形rect = rect1.clip(rect2)# 相交面积为0if rect.width == 0 or rect.height == 0:return False# 相交矩形x,y相对于2个矩形左上角的距离x1, y1 = rect.x - rect1.x, rect.y - rect1.y#计算交集区域相对于 rect1 的相对位置x2, y2 = rect.x - rect2.x, rect.y - rect2.y#同理# 检查相交矩形内的每个点,是否在2个矩形内同时是非透明点,那么就碰撞了for x in range(rect.width):for y in range(rect.height):if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]:return Truereturn False# 展示得分,传入一个整数得分
def showScore(score):# 转成单个数字的列表scoreDigits = [int(x) for x in list(str(score))]#将得分 score 转换成字符串,然后将其每个字符(即每个单独的数字)转换成整数,并存储在列表 scoreDigits 中。这样,得分就被分解成了单个数字的列表# 计算展示所有数字要占多少像素宽度totalWidth = 0for digit in scoreDigits:totalWidth += IMAGES['numbers'][digit].get_width()'''遍历 scoreDigits 列表中的每个数字将每个数字图像的宽度累加到 totalWidth'''# 计算绘制起始x坐标Xoffset = (SCREENWIDTH - totalWidth) / 2# 逐个数字绘制for digit in scoreDigits:SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, 20))    # y坐标贴近屏幕上边缘Xoffset += IMAGES['numbers'][digit].get_width() # 移动绘制x坐标

三、q_game.py

python">"""
强化学习q learning flappy bird
"""
from game.wrapped_flappy_bird import GameState
import time
import numpy as np 
import skimage.color
import skimage.transform
import skimage.exposure
import tensorflow as tf 
import random 
import argparse# 命令行参数
parser = argparse.ArgumentParser()#创建一个 ArgumentParser 对象,用于定义命令行参数
parser.add_argument("--model-only", help="加载已有模型,不随机探索,仍旧训练", action='store_true')
args = parser.parse_args()#解析命令行输入的参数,并将它们存储在 args 变量中# 测试用代码
def _test_save_img(img):# 把每一帧图片存储到文件里,调试用from PIL import Imageim = Image.fromarray((img*255).astype(np.uint8), mode='L') # 图片已经被处理为0~1之间的亮度值,所以*255取整数变灰度展示im.save('./img.jpg')# 构建卷积神经网络
def build_model():# 卷积神经网络:https://blog.csdn.net/FontThrone/article/details/76652753model = tf.keras.models.Sequential([#创建一个 Sequential 模型,它是 tf.keras 中用于线性堆叠网络层的模型类tf.keras.layers.Input(shape=(80,80,4)),tf.keras.layers.Conv2D(filters=32, kernel_size=(8, 8), padding='same',strides=4, activation='relu'),tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),tf.keras.layers.Conv2D(filters=64, kernel_size=(4, 4), padding='same',strides=2, activation='relu'),tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same',strides=1, activation='relu'),tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),tf.keras.layers.Flatten(),#将三维的卷积层输出展平为一维,以便传入到全连接层tf.keras.layers.Dense(256, activation='relu'),#定义一个具有 256 个单元的全连接层,并使用 ReLU 激活函数tf.keras.layers.Dense(2), # 对应2个action未来总回报预期])model.compile(loss='mse', optimizer='adam')#编译模型,指定均方误差(MSE)作为损失函数,使用 Adam 优化器# 尝试加载之前保存的模型参数try:model.load_weights('./weights.h5')print('加载模型成功...................')except:passreturn model# 创建游戏
game = GameState()
# 卷积模型
model = build_model()# 执行1帧游戏
def run_one_frame(action):global game # image_data:执行动作后的图像(288*512*3的RGB三维数组)# reward:本次动作的奖励# terminal:游戏是否失败img, reward, terminal = game.frame_step(action)# RGB转灰度图img = skimage.color.rgb2gray(img)# 压缩到80*80的图片(根据RGB算出来的亮度,其数值很小)img = skimage.transform.resize(img, (80,80))# 把亮度标准化到0~1之间,用作模型输入img = skimage.exposure.rescale_intensity(img, out_range=(0,1))return img,reward,terminal# 强化学习初始化状态
def reset_stat():# 执行第一帧,不点击img_t,_,_ =  run_one_frame([1,0])'''使用 numpy.stack 函数将首帧图像 img_t 重复四次,沿着第三个维度堆叠,形成初始状态 stat_t。这是因为卷积神经网络需要连续几帧的图像作为输入'''stat_t = np.stack([img_t] * 4, axis=2)return stat_t # 初始状态
stat_t = reset_stat()
# 训练样本
transitions = []#用于存储训练过程中的状态转换样本# 时刻
t = 0# 随机探索的概率控制,定义了随机探索概率的初始值、最终值和每次更新的步长。
INIT_EPSILON = 0.1
FINAL_EPSILON = 0.005
EPSLION_DELTA = 1e-6
# 最大留存样本个数
TRANS_CAP =  20000
# 至少有多少样本才训练
TRANS_SIZE_FIT = 10000
# 训练集大小
BATCH_SIZE = 32
# 未来激励折扣
GAMMA = 0.99# 随机探索概率
if args.model_only: # 不随机探索(极低概率)epsilon = FINAL_EPSILON
else:epsilon = INIT_EPSILON# 打印一些进度信息
rand_flap =0    # 随机点击次数
rand_noflap = 0 # 随机不点击次数
model_flap=0    # 模型点击次数
model_noflap=0  # 模型不点击次数
model_train_times = 0   # 模型训练次数# 游戏启动
while True:    # 动作action_t = [0,0]action_type = '随机'#设置动作类型默认为 '随机',这将在选择动作时用于判断动作是随机选择的还是基于模型经验选择的。# 随着学习,降低随机探索的概率,让模型趋于稳定if (t <= TRANS_SIZE_FIT and not args.model_only) or random.random() <= epsilon:'''判断是否应该进行随机探索。如果在观察期内(t <= TRANS_SIZE_FIT)或者随机数小于或等于 epsilon,则执行随机探索'''n = random.random()if n <= 0.95:action_index = 0rand_noflap+=1else:action_index = 1rand_flap+=1#print('[随机探索] t时刻进行随机动作探索...')else: # 模型预测2个操作的未来累计回报action_type = '经验'Q_t = model.predict(np.expand_dims(stat_t, axis=0))[0]#使用当前的模型和状态 stat_t 来预测两个动作的未来总回报action_index = np.argmax(Q_t)   # 回报最大的action下标if action_index==0:model_noflap+=1else:model_flap+=1#print('[已有经验] 预测t时刻2个动作的未来总回报 -- 不点击:{} 点击:{}'.format(Q_t[0], Q_t[1]))action_t[action_index] = 1#print('时刻t将执行的动作为{}'.format(action_t))# 执行当前动作,返回操作后的图片、本次激励、游戏是否结束img_t1, reward, terminal = run_one_frame(action_t)_test_save_img(img_t1)img_t1 = img_t1.reshape((80,80,1)) # 增加通道维度,因为我们要最近4帧作为4通道图片,用作卷积模型输入stat_t1 = np.append(stat_t[:,:,1:], img_t1, axis=2) # 80*80*4,淘汰当前的第0通道,添加最新t1时刻到第3通道# 收集训练样本(保留有限的)transitions.append({'stat_t': stat_t,   # t时刻状态'stat_t1': stat_t1, # t1时刻状态'reward': reward,   # 本次动作的激励得分'terminal': terminal,   # 执行动作后游戏是否结束(ps: 结束意味着没有未来激励了)'action_index': action_index,   # 执行了什么动作(0:不点击,1:点击)})if len(transitions) > TRANS_CAP:transitions.pop(0)# 游戏结束则重置stat_tif terminal:stat_t = reset_stat()#print('死了!!!!!!! 状态t重置为初始帧...')else:   # 否则切为新的状态stat_t = stat_t1#print('没死~~~ 状态t切换为状态t1...')# 过了观察期,开始训练if t >= TRANS_SIZE_FIT and t % 10 == 0:minibatch = random.sample(transitions, BATCH_SIZE)# 模型训练的输入:t时刻的状态(最近4帧图片)inputs_t = np.concatenate([tran['stat_t'].reshape((1,80,80,4)) for tran in minibatch])#print('inputs_t shape', inputs_t.shape)####################################################### 模型训练的输出:t时刻的未来总激励(Q_t = reward+gamma*Q_t1)# 1,让模型预测t时刻2种action的未来总激励Q_t = model.predict(inputs_t, batch_size=len(minibatch))# 2,让模型预测t1时刻2种action的未来总激励input_t1 = np.concatenate([tran['stat_t1'].reshape((1,80,80,4)) for tran in minibatch])Q_t1 = model.predict(input_t1, batch_size=len(minibatch))# 3,保留t1时刻2个action中最大的未来总激励Q_t1_max = [max(q) for q in Q_t1]# 4,t时刻进行action_index动作得到真实激励reward_t = [tran['reward'] for tran in minibatch]# 5,t时刻进行了什么actionaction_index_t = [tran['action_index'] for tran in minibatch]# 6,t1时刻是否死掉了terminal = [tran['terminal'] for tran in minibatch]# 7,修正训练的目标Q_t=reward+gamma*Q_t1# (t时刻action_index的未来总激励=action_index真实激励+t1时刻预测的最大未来总激励)for i in range(len(minibatch)):if terminal[i]:Q_t[i][action_index_t[i]] = reward_t[i] # 因为t1时刻已经死了,所以没有t1之后的累计激励else:Q_t[i][action_index_t[i]] = reward_t[i] + GAMMA*Q_t1_max[i] # Q_t=reward+Q_t1# print('Q_t shape', Q_t.shape)# 训练一波#print(inputs_t)#print(Q_t)model.fit(inputs_t, Q_t, batch_size=len(minibatch))model_train_times += 1# 训练1次则降低些许的随机探索概率if epsilon > FINAL_EPSILON:epsilon -= EPSLION_DELTA# 每5000次batch保存一次模型权重(不适用saved_model,后续加载只会加载权重,模型结构还是程序构造,因为这样可以保持keras model的api)if model_train_times % 5000 == 0:model.save_weights('./weights.h5')######################################################if t % 100 == 0:print('总帧数:{} 剩余探索概率:{}% 累计训练次数:{} 累计随机点:{} 累计随机不点:{} 累计模型点:{} 累计模型不点:{} 训练集:{} '.format(t, round(epsilon * 100, 4), model_train_times, rand_flap, rand_noflap, model_flap, model_noflap,len(transitions)))t = t + 1#time.sleep(1)

四、text_game.py

python">"""
演示pygame制作的flappy bird如何逐帧调用执行
"""
from game.wrapped_flappy_bird import GameState
from random import random
import time# 创建游戏
game = GameState()# 游戏启动
while True:r = random()if r <= 0.92:  # 92%的概率不点击屏幕game.frame_step([1,0]) # 动作:[1,0] 表示不点击else: # 8%的概率点击屏幕game.frame_step([0,1]) # 动作:[0,1] 表示点击

五、训练结果

请添加图片描述

代码源自:强化学习Deep Q-Network自动玩flappy bird | 鱼儿的博客 (yuerblog.cc)

仅想具体看一下工作原理和代码,仅供学习使用


http://www.ppmy.cn/news/1455269.html

相关文章

电脑数据怎么拷贝到u盘?操作指南与数据丢失防范

在数字时代&#xff0c;数据的传输与备份已成为我们日常生活和工作中不可或缺的一部分。U盘作为一种便捷、高效的移动存储设备&#xff0c;广泛应用于各种数据拷贝场景。无论是个人文件的备份&#xff0c;还是工作资料的传输&#xff0c;U盘都发挥着举足轻重的作用。那么&#…

状压dp 理论例题 详解

状压dp 四川2005年省选题&#xff1a;互不侵犯 首先我们可以分析一下&#xff0c;按照我们普通的思路&#xff0c;就是用搜索&#xff0c;枚举每一行的每一列&#xff0c;尝试放下一个国王&#xff0c;然后标记&#xff0c;继续枚举下一行 那么&#xff0c;我们的时间复杂度…

LNMP一键安装包

LNMP一键安装包是什么? LNMP一键安装包是一个用Linux Shell编写的可以为CentOS/RHEL/Fedora/Debian/Ubuntu/Raspbian/Deepin/Alibaba/Amazon/Mint/Oracle/Rocky/Alma/Kali/UOS/银河麒麟/openEuler/Anolis OS Linux VPS或独立主机安装LNMP(Nginx/MySQL/PHP)、LNMPA(Nginx/MySQ…

EasyExcel 处理 Excel

序言 本文介绍在日常的开发中&#xff0c;如何使用 EasyExcel 高效处理 Excel。 一、EasyExcel 是什么 EasyExcel 是阿里巴巴开源的一个 Java Excel 操作类库&#xff0c;它基于 Apache POI 封装了简单易用的 API&#xff0c;使得我们能够方便地读取、写入 Excel 文件。Easy…

Hexview工具使用说明

一般Davinci工具都会在Misc路径下面配一个hexview工具。Hexview工具是免安装的&#xff0c;功能非常强大&#xff0c;可以打开并解析hex文件和srec文件&#xff0c;哪怕这两种文件格式不一样&#xff0c;解析出来的结果是一样的。 文件描述 _examples是例子 _expdatproc是用…

VUE3从入门到精通

第一章> 1、前端工程化是什么 2、webpack的作用 3、plugin的基本使用 4、loader的基本使用 5、SourceMap的作用 第二章> 1、VUE基本使用步骤 2、各种指令的使用 3、过滤器 4、实际案例 第三章> 1、单页面应用与组件化开发 2、vue三个组成部分…

Visual Studio C++ 的一个简单示例

Visual Studio 项目属性设置&#xff1a; 项目属性→C/C→常规→附加包含目录 C:\Intel\include\iconv\include;项目属性→链接器→常规→附加库目录 C:\Intel\include\iconv\lib;项目属性→链接器→输入→附加依赖项 iconv.lib;提示缺少"iconv.dll"&#xff0c;…

SpringBoot优雅地定制JSON响应数据

提示&#xff1a;文章若有错误&#xff0c;欢迎评论区指正&#x1f36d; 文章目录 前言 一、如何使用JsonView这个注解&#xff1f; 二、应用场景 三、实战案例 注解方式 编程方式 总结 前言 最近在学习过程中发现了Jackson库的JsonView也可以改变JSON的输出结构&#xff0c;…