pytorch训练五子棋ai

news/2025/2/21 11:15:52/

有3个文件

game.py  五子棋游戏

mod.py  神经网络模型

xl.py   训练的代码

aigame.py   玩家与对战的五子棋

game.py

python"> 
class Game:def __init__(self, h, w):# 行数self.h = h# 列数self.w = w# 棋盘self.L = [['-' for _ in range(w)] for _ in range(h)]# 当前玩家 - 表示空 X先下 然后是Oself.cur = 'X'# 游戏胜利者self.win_user = None# 检查下完这步后有没有赢 y是行 x是列 返回True表示赢def check_win(self, y, x):directions = [# 水平、垂直、两个对角线方向(1, 0), (0, 1), (1, 1), (1, -1)]player = self.L[y][x]for dy, dx in directions:count = 0# 检查四个方向上的连续相同棋子for i in range(-4, 5):  # 检查-4到4的范围,因为五子连珠需要5个棋子ny, nx = y + i * dy, x + i * dxif 0 <= ny < self.h and 0 <= nx < self.w and self.L[ny][nx] == player:count += 1if count == 5:return Trueelse:count = 0return False# 检查能不能下这里 y行 x列 返回True表示能下def check(self, y, x):return self.L[y][x] == '-' and self.win_user is None# 打印棋盘 可视化用得到def __str__(self):# 确定行号和列号的宽度row_width = len(str(self.h - 1))col_width = len(str(self.w - 1))# 生成带有行号和列号的棋盘字符串表示result = []# 添加列号标题result.append(' ' * (row_width + 1) + ' '.join(f'{i:>{col_width}}' for i in range(self.w)))# 添加分隔线(可选)result.append(' ' * (row_width + 1) + '-' * (col_width * self.w))# 添加棋盘行for y, row in enumerate(self.L):# 添加行号result.append(f'{y:>{row_width}} ' + ' '.join(f'{cell:>{col_width}}' for cell in row))return '\n'.join(result)# 一步棋def set(self, y, x):if self.win_user or not self.check(y, x):return Falseself.L[y][x] = self.curif self.check_win(y, x):self.win_user = self.curreturn Trueself.cur = 'X' if self.cur == 'O' else 'O'return True#和棋def heqi(self):for y in range(self.h):for x in range(self.w):if self.L[y][x]=='-':return Falsereturn True#玩家自己下
def run_game01():g = Game(15, 15)while not g.win_user:# 打印当前棋盘状态while 1:print(g)try:y,x=input(g.cur+':').split(',')x=int(x)y=int(y)if g.set(y,x):breakexcept Exception as e:print(e)print(g)print('胜利者',g.win_user)

mod.py

python">import torch
import torch.nn as nn
import torch.optim as optim
from game import Gameclass MyMod(nn.Module):def __init__(self, input_channels=1, output_size=15*15):super(MyMod, self).__init__()# 定义卷积层,用于提取特征self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)  # 输出 32 x 15 x 15self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 输出 64 x 15 x 15self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # 输出 128 x 15 x 15# 定义全连接层,用于最后的得分预测self.fc1 = nn.Linear(128 * 15 * 15, 1024)  # 展平后传入全连接层self.fc2 = nn.Linear(1024, output_size)  # 输出 15*15 的得分预测def forward(self, x):# 卷积层 -> 激活函数 -> 最大池化x = torch.relu(self.conv1(x))x = torch.relu(self.conv2(x))x = torch.relu(self.conv3(x))# 将卷积层输出展平为一维x = x.view(x.size(0), -1)# 全连接层x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 保存模型权重def save(self, path):torch.save(self.state_dict(), path)# 加载模型权重def load(self, path):self.load_state_dict(torch.load(path))#改进一下  output 把有棋子的地方的概率=0避免下这些地方
# 输入Game对象和MyMod对象,用于得到概率最大的落棋点 (行y, 列x)
def input_qi(g: Game, m: MyMod):# 获取当前棋盘状态board_state = g.L  # 使用 game.L 获取当前棋盘的状态 (15x15的二维列表)# 将棋盘状态转换为PyTorch的Tensor并增加一个维度(batch_size = 1)board_tensor = torch.tensor([[1 if cell == 'X' else -1 if cell == 'O' else 0 for cell in row] for row in board_state], dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # 形状变为 (1, 1, 15, 15)# 传入模型获取每个位置的得分output = m(board_tensor)# 将输出转为概率值(可以使用softmax来归一化)probabilities = torch.softmax(output, dim=-1).view(g.h, g.w).detach().numpy()  # 变为 (15, 15) 大小# 将已有棋子的位置的概率设置为 -inf,避免选择这些位置for y in range(g.h):for x in range(g.w):if board_state[y][x] != '-':probabilities[y, x] = -float('inf')  # 设置已经有棋子的地方的概率为 -inf# 找到概率最大的落子点max_prob_pos = divmod(probabilities.argmax(), g.w)  # 得到最大概率的行列坐标# 确保返回的是合法的位置y, x = max_prob_posreturn (y, x), output  # 返回坐标和模型输出

xl.py

python">import os
import torch
import torch.optim as optim
import torch.nn.functional as F
from mod import MyMod, input_qi, Game# 两个权重文件,分别代表 X 棋和 O 棋
MX = 'MX'
MO = 'MO'# 加载模型,若文件不存在则初始化
def load_model(model, path):if os.path.exists(path):model.load(path)print(f"Loaded model from {path}")else:print(f"{path} not found, initializing new model.")# 这里可以加一些初始化模型的代码,例如:# model.apply(init_weights) 如果需要初始化权重# 初始化模型
modx = MyMod()
load_model(modx, MX)modo = MyMod()
load_model(modo, MO)# 定义优化器
lr=0.001
optimizer_x = optim.Adam(modx.parameters(), lr=lr)
optimizer_o = optim.Adam(modo.parameters(), lr=lr)# 损失函数:根据游戏结果调整损失
def compute_loss(winner: int, player: str, model_output):# 将目标值转换为相应的张量if player == "X":if winner == 1:  # X 胜target = torch.tensor(1.0, dtype=torch.float32)elif winner == 0:  # 平局target = torch.tensor(0.5, dtype=torch.float32)else:  # X 输target = torch.tensor(0.0, dtype=torch.float32)else:if winner == -1:  # O 胜target = torch.tensor(1.0, dtype=torch.float32)elif winner == 0:  # 平局target = torch.tensor(0.5, dtype=torch.float32)else:  # O 输target = torch.tensor(0.0, dtype=torch.float32)# 确保目标值的形状和 model_output 一致,假设 model_output 是单一的值target = target.unsqueeze(0).unsqueeze(0)  # 形状变为 (1, 1)# 使用均方误差损失计算return F.mse_loss(model_output, target)# 训练模型的过程
def train_game():modx.train()modo.train()# 创建新的游戏实例game = Game(15, 15)  # 默认是 15x15 棋盘# 反向传播和优化optimizer_x.zero_grad()optimizer_o.zero_grad()while not game.win_user:  # 游戏未结束# X 方落子x_move, x_output = input_qi(game, modx)  # 获取落子位置和模型输出(x_output 是模型的输出)game.set(x_move[0], x_move[1])  # X 下棋if game.win_user:break# O 方落子o_move, o_output = input_qi(game, modo)  # 获取落子位置和模型输出(o_output 是模型的输出)#print(o_move,game)game.set(o_move[0], o_move[1])  # O 下棋# 获取比赛结果winner = 0 if game.heqi() else (1 if game.win_user == 'X' else -1)  # 1为X胜,-1为O胜,0为平局# 计算损失loss_x = compute_loss(winner, "X", x_output)  # 传递模型输出给计算损失函数loss_o = compute_loss(winner, "O", o_output)  # 传递模型输出给计算损失函数# 计算损失并进行反向传播loss_x.backward()loss_o.backward()# 更新权重optimizer_x.step()optimizer_o.step()print(game)return loss_x.item(), loss_o.item()# 训练多个回合
def train(num_epochs,n):k=0for epoch in range(num_epochs):loss_x, loss_o = train_game()print(f"Epoch [{epoch+1}/{num_epochs}], Loss X: {loss_x}, Loss O: {loss_o}")k+=1if k==n:modo.save('MO')modx.save('MX')print('saved')k=0# 开始训练
train(50000,1000)

aigame.py

python">from game import Game
from mod import MyMod,input_qi#玩家下X ai下O
def playX():m=MyMod()m.load('MO')g=Game(15,15)while 1:print(g)if g.heqi() or g.win_user:breakwhile 1:try:r=input('X:')y,x=r.split(',')y=int(y)x=int(x)if g.set(y,x):breakexcept Exception as e:print(e)if g.heqi() or g.win_user:breakwhile 1:(y,x),_=input_qi(g,m)if g.set(y,x):breakprint(g)print('winner',g.win_user)#玩家下O ai下X
def playO():m=MyMod()m.load('MX')g=Game(15,15)while 1:if g.heqi() or g.win_user:breakwhile 1:(y,x),_=input_qi(g,m)if g.set(y,x):breakif g.heqi() or g.win_user:breakprint(g)while 1:try:r=input('O:')y,x=r.split(',')y=int(y)x=int(x)if g.set(y,x):breakexcept Exception as e:print(e)print(g)print('winner',g.win_user)playX()


http://www.ppmy.cn/news/1573866.html

相关文章

【数据结构-并查集】力扣1202. 交换字符串中的元素

给你一个字符串 s&#xff0c;以及该字符串中的一些「索引对」数组 pairs&#xff0c;其中 pairs[i] [a, b] 表示字符串中的两个索引&#xff08;编号从 0 开始&#xff09;。 你可以 任意多次交换 在 pairs 中任意一对索引处的字符。 返回在经过若干次交换后&#xff0c;s …

1.3 嵌入式系统的固件

以STM32F103C8T6单片机举例&#xff0c;固件代码都是放在Flash闪存中&#xff0c;以Keil界面举例&#xff0c;该界面分为启动代码&#xff0c;库函数代码&#xff0c;还有用户代码&#xff0c;编译时&#xff0c;这些代码会被编译并链接成一个单一的固件映像&#xff0c;然后通…

C语言-进程

1、进程是什么&#xff1f; 一个具有一定独立功能的程序关于某个数据集合的一次运行活动&#xff08;程序执行的过程&#xff09;&#xff0c; 是系统进行资源分配和调度运行的基本单位。是动态的&#xff0c;随着程序的使用被创建&#xff0c;随着 程序的结束而消亡。 什么是程…

Stack和Queue—模拟实现,实战应用全解析!

各位看官早安午安晚安呀 如果您觉得这篇文章对您有帮助的话 欢迎您一键三连&#xff0c;小编尽全力做到更好 欢迎您分享给更多人哦 大家好&#xff0c;我们今天来学习java数据结构的Stack和Queue&#xff08;栈和队列&#xff09; 一&#xff1a;栈 1.1&#xff1a;栈的概念 …

迪威模型网:免费畅享 3D 打印盛宴,科技魅力与趣味创意并存

还在为寻找优质3D打印模型而发愁&#xff1f;快来迪威模型网&#xff08;https://www.3dwhere.com/&#xff09;&#xff0c;一个集前沿科技与无限趣味于一体的免费3D打印宝藏平台&#xff01; 踏入迪威模型网&#xff0c;仿佛开启一场未来科技之旅。其“3D打印”专区&#xff…

【OpenCV】OpenCV 中各模块及其算子的详细分类

OpenCV 的最新版本包含了 500 多个算子&#xff0c;这些算子覆盖了图像处理、计算机视觉、机器学习、深度学习、视频分析等多个领域。为了方便使用&#xff0c;OpenCV 将这些算子分为多个模块&#xff0c;每个模块承担特定的功能。 以下是 OpenCV 中各模块及其算子的详细分类&…

“深入浅出”系列之QT:(10)Qt接入Deepseek

项目配置&#xff1a; 在.pro文件中添加网络模块&#xff1a; QT core network API配置&#xff1a; 将apiUrl替换为实际的DeepSeek API端点 将apiKey替换为你的有效API密钥 根据API文档调整请求参数&#xff08;模型名称、温度值等&#xff09; 功能说明&#xff1a; 使…

后台管理系统-月卡管理

功能说明并准备静态结构 <template><div class"card-container"><!-- 搜索区域 --><div class"search-container"><span class"search-label">车牌号码&#xff1a;</span><el-input clearable placeho…