20250220 - Code Notes 01 - class CVRPEnv


Table of Contents

  • Preface
  • 1. def __init__(self, **env_params)
    • Functionality
    • Code
  • 2. use_saved_problems(self, filename, device)
    • Functionality
    • Code
  • 3. load_problems(self, batch_size, aug_factor=1)
    • Functionality
    • Code
    • How use_saved_problems and load_problems relate
  • 4. reset(self)
    • Functionality
    • Code
  • 5. pre_step(self)
    • Functionality
    • Code
  • 6. step(self, selected)
    • Functionality
    • Code
  • 7. _get_travel_distance(self)
    • Functionality
    • Question
      • What is "rolling"?
    • Code
  • Appendix
    • Full code: CVRPEnv.py


Preface

These notes study the class CVRPEnv defined in CVRPEnv.py.
The code lives at:

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py


1. def __init__(self, **env_params)

Functionality

This is the initialization method of the CVRPEnv class. It sets up the parameters and state variables of the Capacitated Vehicle Routing Problem (CVRP) environment.

(Mind map of the constructor's parameters: image not reproduced here.)

Code

    def __init__(self, **env_params):
        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']  # number of customer nodes
        self.pomo_size = env_params['pomo_size']        # number of POMO rollouts (agents)

        self.FLAG__use_saved_problems = False           # whether to use pre-saved problem instances
        self.saved_depot_xy = None                      # depot coordinates
        self.saved_node_xy = None                       # node (customer) coordinates
        self.saved_node_demand = None                   # node demands
        self.saved_index = None                         # read cursor into the saved instances

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_demand = None
        # shape: (batch, problem+1)

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        # shape: (batch, pomo)
        self.load = None
        # shape: (batch, pomo)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+2), incl. the regret dummy node
        self.ninf_mask = None
        # shape: (batch, pomo, problem+2), incl. the regret dummy node
        self.finished = None
        # shape: (batch, pomo)

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

        # regret
        ####################################
        self.mode = None
        self.last_current_node = None
        self.last_load = None
        self.regret_count = None
        self.regret_mask_matrix = None
        self.add_mask_matrix = None
        self.time_step = 0
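As a quick usage sketch (the parameter values here are illustrative, not taken from the source), the constructor is fed a plain dict of environment parameters:

    # minimal usage sketch; the parameter values are illustrative
    env_params = {
        'problem_size': 20,   # number of customer nodes
        'pomo_size': 20,      # number of parallel POMO rollouts per instance
    }
    env = CVRPEnv(**env_params)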

2. use_saved_problems(self, filename, device)

Functionality

This function loads pre-saved problem instances and stores their data in class attributes.
It reads the depot coordinates (depot_xy), the node coordinates (node_xy), and the node demands (node_demand) from the given file and keeps them for later use.

(Mind map of the function: image not reproduced here.)

Code

    def use_saved_problems(self, filename, device):
        self.FLAG__use_saved_problems = True

        loaded_dict = torch.load(filename, map_location=device)  # load the saved problem instances
        self.saved_depot_xy = loaded_dict['depot_xy']            # parse the loaded data
        self.saved_node_xy = loaded_dict['node_xy']
        self.saved_node_demand = loaded_dict['node_demand']
        self.saved_index = 0
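For reference, here is a sketch of how a file that use_saved_problems can read might be produced. The dict keys match the code above, and the shapes follow the shape comments in Reset_State; the filename and random values are stand-ins:

    import torch

    n_instances, problem_size = 1000, 20
    torch.save({
        'depot_xy': torch.rand(n_instances, 1, 2),             # depot coordinates
        'node_xy': torch.rand(n_instances, problem_size, 2),   # customer coordinates
        'node_demand': torch.rand(n_instances, problem_size),  # customer demands
    }, 'cvrp20_instances.pt')                                  # hypothetical filename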

3. load_problems(self, batch_size, aug_factor=1)

Functionality

This function loads a batch of Capacitated Vehicle Routing Problem (CVRP) instances. It:

  1. generates problem instances on the fly, or extracts them from preloaded data
  2. optionally applies data augmentation
  3. initializes the index tensors and state variables
  4. stores everything in the environment's attributes

How it works

  • If self.FLAG__use_saved_problems is True, the data (self.saved_depot_xy, self.saved_node_xy, self.saved_node_demand) is sliced out of the instances previously loaded by use_saved_problems, and the read cursor self.saved_index is advanced.
  • If self.FLAG__use_saved_problems is False, problem instances are generated dynamically: get_random_problems() produces data for the given batch_size and problem_size.
  • load_problems also supports data augmentation: the aug_factor argument (currently only aug_factor=8 is supported) multiplies the batch by transforming the coordinates of every instance, while the demands are repeated unchanged; see the sketch just after this list.
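The augmentation itself lives in CVRProblemDef.py. As a sketch of what the standard POMO 8-fold augmentation does (an assumption based on the original POMO code; the repo's own augment_xy_data_by_8_fold is the authoritative version), it stacks the eight symmetries of the unit square along the batch dimension:

    import torch

    def augment_xy_by_8_fold_sketch(xy):
        # xy shape: (batch, N, 2) -> returns (8*batch, N, 2)
        x, y = xy[:, :, [0]], xy[:, :, [1]]
        variants = [
            torch.cat((x, y), dim=2),      torch.cat((1 - x, y), dim=2),
            torch.cat((x, 1 - y), dim=2),  torch.cat((1 - x, 1 - y), dim=2),
            torch.cat((y, x), dim=2),      torch.cat((1 - y, x), dim=2),
            torch.cat((y, 1 - x), dim=2),  torch.cat((1 - y, 1 - x), dim=2),
        ]
        return torch.cat(variants, dim=0)

Reflections do not change the demands, which is why node_demand is simply repeated 8 times in load_problems.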

(Mind map of the function's behavior: image not reproduced here.)

Code

    def load_problems(self, batch_size, aug_factor=1):
        self.batch_size = batch_size

        # load problem instances
        if not self.FLAG__use_saved_problems:
            # generation mode: create instances on the fly
            depot_xy, node_xy, node_demand = get_random_problems(batch_size, self.problem_size)
        else:
            # preloaded mode: slice instances out of the saved data
            depot_xy = self.saved_depot_xy[self.saved_index:self.saved_index+batch_size]
            node_xy = self.saved_node_xy[self.saved_index:self.saved_index+batch_size]
            node_demand = self.saved_node_demand[self.saved_index:self.saved_index+batch_size]
            self.saved_index += batch_size

        # data augmentation
        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                depot_xy = augment_xy_data_by_8_fold(depot_xy)
                node_xy = augment_xy_data_by_8_fold(node_xy)
                node_demand = node_demand.repeat(8, 1)
            else:
                raise NotImplementedError

        # concatenate depot and customer data
        self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        # shape: (batch, problem+1, 2)
        depot_demand = torch.zeros(size=(self.batch_size, 1))
        # shape: (batch, 1)
        self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        # shape: (batch, problem+1)

        # build the batch and POMO index grids
        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

        # populate the reset and step states
        self.reset_state.depot_xy = depot_xy
        self.reset_state.node_xy = node_xy
        self.reset_state.node_demand = node_demand
        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.POMO_IDX = self.POMO_IDX

How use_saved_problems and load_problems relate

  • use_saved_problems is a precondition for loading saved data:

    • use_saved_problems reads a previously saved instance file (for example, one written with torch.save()) and stores its contents in dedicated attributes (such as self.saved_depot_xy and self.saved_node_xy).

    • Once it has run, it sets self.FLAG__use_saved_problems = True, which tells all subsequent operations to draw problem instances from the saved data instead of generating new ones.

    • However, use_saved_problems does not build the per-batch problem tensors itself; it only stages the raw data and raises the flag that later loading calls (such as load_problems) consult.

  • load_problems consumes the data staged by use_saved_problems:

    • load_problems is the main entry point for loading and generating problems. Depending on the value of self.FLAG__use_saved_problems, it either extracts instances from the saved data or generates fresh random ones.
    • When self.FLAG__use_saved_problems = True, load_problems reads from self.saved_depot_xy, self.saved_node_xy, and self.saved_node_demand and performs the per-batch processing (advancing the cursor, data augmentation, and so on).
    • When self.FLAG__use_saved_problems = False, load_problems calls get_random_problems() to generate the data dynamically.
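Putting the two functions together, the two loading paths look like this (a sketch; the filename is hypothetical):

    env = CVRPEnv(problem_size=20, pomo_size=20)

    # Path A: generate random instances on the fly
    env.load_problems(batch_size=64)

    # Path B: evaluate on pre-saved instances, with 8-fold augmentation
    env.use_saved_problems('cvrp20_instances.pt', device='cpu')
    env.load_problems(batch_size=64, aug_factor=8)   # effective batch size becomes 512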

4. reset(self)

Functionality

reset puts the environment's state variables back to their initial values; it is typically called at the start of every training episode or experiment. It guarantees a known initial state, so the agents can begin deciding and learning from a clean slate.

(Mind map of the function's parameters: image not reproduced here.)

Code

    def reset(self):
        # reset the selection counter
        self.selected_count = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long)
        # reset the current node
        self.current_node = None
        # shape: (batch, pomo)
        # reset the list of selected nodes
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~)

        # every rollout starts at the depot
        self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)
        # start with full vehicle capacity
        self.load = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)

        # visited mask; column problem_size+1 is the regret dummy node, blocked initially
        self.visited_ninf_flag = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.visited_ninf_flag[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+2)
        # feasibility mask
        self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.ninf_mask[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+2)
        # completion flags
        self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        # initialize the regret-related state
        self.regret_count = torch.zeros((self.batch_size, self.pomo_size))
        self.mode = torch.full((self.batch_size, self.pomo_size), 0)
        self.last_current_node = None
        self.last_load = None
        self.time_step = 0

        reward = None
        done = False
        return self.reset_state, reward, done
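A small sketch of what reset leaves behind (the sizes follow the code above; the concrete numbers are illustrative):

    env = CVRPEnv(problem_size=20, pomo_size=20)
    env.load_problems(batch_size=4)
    reset_state, reward, done = env.reset()

    print(env.load.shape)               # torch.Size([4, 20]); every rollout starts with load 1.0
    print(env.visited_ninf_flag.shape)  # torch.Size([4, 20, 22]); 22 = depot + 20 customers + regret dummy node
    print(reward, done)                 # None False; nothing has been decided yet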

5. pre_step(self)

Functionality

pre_step is a preprocessing step that packages the state needed before the first time step.
In a reinforcement-learning loop, every time step updates the environment from the current state and action; pre_step supplies the state from which the first decision is made.

(Mind map of the function's behavior: image not reproduced here.)

Code

    def pre_step(self):
        # reset selected_count
        self.step_state.selected_count = 0
        # copy the current load
        self.step_state.load = self.load
        # set the current node
        self.step_state.current_node = self.current_node
        # copy the feasibility mask
        self.step_state.ninf_mask = self.ninf_mask

        # return the step state, reward, and done flag
        reward = None
        done = False
        return self.step_state, reward, done
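In a typical POMO-style rollout, pre_step supplies the state for the very first decoder call, and step drives every move after that. A sketch (model is a hypothetical policy that maps a Step_State to one node index per rollout):

    env.load_problems(batch_size=64)
    reset_state, _, _ = env.reset()

    state, reward, done = env.pre_step()   # state for the first decoder call
    while not done:
        selected = model(state)            # hypothetical policy: (batch, pomo) node indices
        state, reward, done = env.step(selected)
    # at termination, reward holds -tour_length for every rollout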

6. step(self, selected)

Functionality

step advances the environment by one time step: it updates each agent's state, applies the selected move, updates the load, and records the chosen node, finally returning the current state, the reward, and a flag indicating whether the task is done.

(Mind map of the function's behavior and parameters: image not reproduced here.)

Code

    def step(self, selected):
        # selected.shape: (batch, pomo)

        # time-step control
        if self.time_step < 4:
            # the first few steps follow the plain POMO transition
            self.time_step = self.time_step + 1
            self.selected_count = self.selected_count + 1

            # are we at the depot?
            self.at_the_depot = (selected == 0)

            # bookkeeping at specific time steps
            if self.time_step == 3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                # unlock the regret dummy node once a regret move becomes meaningful
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot) & (self.last_current_node != 0)] = 0

            # update the current node and the list of selected nodes
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

            # update demands and load
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            # mark visited nodes (prevents re-selecting them)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

            # rebuild the feasibility mask (blocks nodes whose demand exceeds the remaining load)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2 = torch.full((demand_too_large.shape[0], demand_too_large.shape[1], 1), False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')

            # sync the updated state into self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask

        # from time step 4 onwards, the regret machinery kicks in
        else:
            # classify the action of each rollout by its mode
            action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
            action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
            action2_bool_index = self.mode == 1
            action3_bool_index = self.mode == 2
            action1_index = torch.nonzero(action1_bool_index)
            action2_index = torch.nonzero(action2_bool_index)
            action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

            # update the selection counter
            self.selected_count = self.selected_count + 1
            # a regret action undoes the previous selection
            self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

            # node updates
            self.last_is_depot = (self.last_current_node == 0)
            _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

            # append to the list of selected nodes
            self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

            # update the load
            self.at_the_depot = (selected == 0)
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            # shape: (batch, pomo, problem+1)
            _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
            # extend demand_list with a zero-demand column for the regret dummy node
            demand_list = torch.cat((demand_list, _3), dim=2)
            gathering_index = selected[:, :, None]
            # shape: (batch, pomo, 1)
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
            self.last_load = self.load.clone()
            # shape: (batch, pomo)
            self.load -= selected_demand
            self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            # update the visited mask
            self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
            self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

            # rebuild the feasibility mask
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            # shape: (batch, pomo, problem+2)
            self.ninf_mask[demand_too_large] = float('-inf')

            # update completion: a rollout is finished once every real node is masked as visited
            newly_finished = (self.visited_ninf_flag == float('-inf'))[:, :, :self.problem_size+1].all(dim=2)
            # shape: (batch, pomo)
            self.finished = self.finished + newly_finished
            # shape: (batch, pomo)

            # update the modes
            self.mode[action1_bool_index] = 1
            self.mode[action2_bool_index] = 2
            self.mode[action3_bool_index] = 0
            self.mode[self.finished] = 4

            # once finished, only the depot stays selectable
            self.ninf_mask[:, :, 0][self.finished] = 0
            self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

            # sync the step state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask

        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None
        return self.step_state, reward, done
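Why -inf rather than, say, zeroing out probabilities afterwards: adding -inf to a logit makes its softmax probability exactly zero, so a masked node can never be sampled. A standalone sketch of the mechanism:

    import torch

    logits = torch.tensor([1.0, 2.0, 0.5])
    ninf_mask = torch.tensor([0.0, float('-inf'), 0.0])  # node 1 is visited/infeasible
    probs = torch.softmax(logits + ninf_mask, dim=0)
    print(probs)  # tensor([0.6225, 0.0000, 0.3775]); node 1 can never be chosen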

7. _get_travel_distance(self)

Functionality

_get_travel_distance computes, for every POMO rollout, the total travel distance along the sequence of nodes it selected over the time steps.

(Mind map of the function's parameters and flow: image not reproduced here.)

Question

What is "rolling"?

"Rolling" is an operation on a tensor or array that shifts its elements along a chosen dimension (usually the time dimension) to produce a new tensor; elements shifted past the end wrap around to the other side.

Example
Suppose we have a one-dimensional tensor recording node selections over time:

tensor = torch.tensor([1, 2, 3, 4, 5])

If we roll this tensor one step to the right along the time dimension:

rolled_tensor = tensor.roll(dims=0, shifts=1)

then rolled_tensor becomes:

tensor([5, 1, 2, 3, 4])
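_get_travel_distance uses exactly this trick along dim=2: rolling the selected-node sequence by one pairs every node with its predecessor, so all edge lengths can be computed without an explicit pairwise loop. A minimal sketch of the idea on a single tour:

    import torch

    coords = torch.rand(10, 2)            # 10 points in the unit square
    tour = torch.randperm(10)             # some visiting order
    xy = coords[tour]                     # coordinates in tour order
    prev_xy = xy.roll(shifts=1, dims=0)   # each node paired with its predecessor
    edge_lengths = ((xy - prev_xy) ** 2).sum(dim=1).sqrt()
    tour_length = edge_lengths.sum()      # closed tour: the roll links the last node back to the first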

Code

    def _get_travel_distance(self):
        # m1 marks the steps at which the regret dummy node was selected
        m1 = (self.selected_node_list == self.problem_size+1)
        # m2 marks regret steps and the step just before each of them
        m2 = (m1.roll(dims=2, shifts=-1) | m1)
        # m3 marks the step just after each regret step
        m3 = m1.roll(dims=2, shifts=1)
        # m4 marks ordinary steps, untouched by any regret
        m4 = ~(m2 | m3)

        # each step paired with its predecessor; after a regret, with the node 3 steps back
        selected_node_list_right = self.selected_node_list.roll(dims=2, shifts=1)
        selected_node_list_right2 = self.selected_node_list.roll(dims=2, shifts=3)

        self.regret_mask_matrix = m1
        self.add_mask_matrix = (~m2)

        travel_distances = torch.zeros((self.batch_size, self.pomo_size))
        for t in range(self.selected_node_list.shape[2]):
            add1_index = (m4[:, :, t].unsqueeze(2)).nonzero()
            add3_index = (m3[:, :, t].unsqueeze(2)).nonzero()
            # ordinary edge: node chosen at step t back to its predecessor
            travel_distances[add1_index[:, 0], add1_index[:, 1]] = (
                travel_distances[add1_index[:, 0], add1_index[:, 1]].clone()
                + ((self.depot_node_xy[add1_index[:, 0], self.selected_node_list[add1_index[:, 0], add1_index[:, 1], t], :]
                    - self.depot_node_xy[add1_index[:, 0], selected_node_list_right[add1_index[:, 0], add1_index[:, 1], t], :]) ** 2).sum(1).sqrt()
            )
            # edge right after a regret: connect to the node from 3 steps back instead
            travel_distances[add3_index[:, 0], add3_index[:, 1]] = (
                travel_distances[add3_index[:, 0], add3_index[:, 1]].clone()
                + ((self.depot_node_xy[add3_index[:, 0], self.selected_node_list[add3_index[:, 0], add3_index[:, 1], t], :]
                    - self.depot_node_xy[add3_index[:, 0], selected_node_list_right2[add3_index[:, 0], add3_index[:, 1], t], :]) ** 2).sum(1).sqrt()
            )
        return travel_distances

Appendix

Full code: CVRPEnv.py


/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py


from dataclasses import dataclass
import torch

from CVRProblemDef import get_random_problems, augment_xy_data_by_8_fold


@dataclass
class Reset_State:
    depot_xy: torch.Tensor = None
    # shape: (batch, 1, 2)
    node_xy: torch.Tensor = None
    # shape: (batch, problem, 2)
    node_demand: torch.Tensor = None
    # shape: (batch, problem)


@dataclass
class Step_State:
    BATCH_IDX: torch.Tensor = None      # batch indices
    POMO_IDX: torch.Tensor = None       # indices of the POMO rollouts
    # shape: (batch, pomo)
    selected_count: int = None          # number of nodes selected so far
    load: torch.Tensor = None           # current load state
    # shape: (batch, pomo)
    current_node: torch.Tensor = None   # index of the node currently being visited
    # shape: (batch, pomo)
    ninf_mask: torch.Tensor = None      # negative-infinity feasibility mask
    # shape: (batch, pomo, problem+2), incl. the regret dummy node


class CVRPEnv:
    def __init__(self, **env_params):
        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']  # number of customer nodes
        self.pomo_size = env_params['pomo_size']        # number of POMO rollouts (agents)

        self.FLAG__use_saved_problems = False           # whether to use pre-saved problem instances
        self.saved_depot_xy = None                      # depot coordinates
        self.saved_node_xy = None                       # node (customer) coordinates
        self.saved_node_demand = None                   # node demands
        self.saved_index = None                         # read cursor into the saved instances

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_demand = None
        # shape: (batch, problem+1)

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        # shape: (batch, pomo)
        self.load = None
        # shape: (batch, pomo)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+2), incl. the regret dummy node
        self.ninf_mask = None
        # shape: (batch, pomo, problem+2), incl. the regret dummy node
        self.finished = None
        # shape: (batch, pomo)

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

        # regret
        ####################################
        self.mode = None
        self.last_current_node = None
        self.last_load = None
        self.regret_count = None
        self.regret_mask_matrix = None
        self.add_mask_matrix = None
        self.time_step = 0

    # load saved problem instance data
    def use_saved_problems(self, filename, device):
        self.FLAG__use_saved_problems = True

        loaded_dict = torch.load(filename, map_location=device)  # load the saved problem instances
        self.saved_depot_xy = loaded_dict['depot_xy']            # parse the loaded data
        self.saved_node_xy = loaded_dict['node_xy']
        self.saved_node_demand = loaded_dict['node_demand']
        self.saved_index = 0

    def load_problems(self, batch_size, aug_factor=1):
        self.batch_size = batch_size

        # load problem instances
        if not self.FLAG__use_saved_problems:
            # generation mode: create instances on the fly
            depot_xy, node_xy, node_demand = get_random_problems(batch_size, self.problem_size)
        else:
            # preloaded mode: slice instances out of the saved data
            depot_xy = self.saved_depot_xy[self.saved_index:self.saved_index+batch_size]
            node_xy = self.saved_node_xy[self.saved_index:self.saved_index+batch_size]
            node_demand = self.saved_node_demand[self.saved_index:self.saved_index+batch_size]
            self.saved_index += batch_size

        # data augmentation
        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                depot_xy = augment_xy_data_by_8_fold(depot_xy)
                node_xy = augment_xy_data_by_8_fold(node_xy)
                node_demand = node_demand.repeat(8, 1)
            else:
                raise NotImplementedError

        # concatenate depot and customer data
        self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        # shape: (batch, problem+1, 2)
        depot_demand = torch.zeros(size=(self.batch_size, 1))
        # shape: (batch, 1)
        self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        # shape: (batch, problem+1)

        # build the batch and POMO index grids
        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

        # populate the reset and step states
        self.reset_state.depot_xy = depot_xy
        self.reset_state.node_xy = node_xy
        self.reset_state.node_demand = node_demand
        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.POMO_IDX = self.POMO_IDX

    def reset(self):
        # reset the selection counter
        self.selected_count = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long)
        # reset the current node
        self.current_node = None
        # shape: (batch, pomo)
        # reset the list of selected nodes
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~)

        # every rollout starts at the depot
        self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)
        # start with full vehicle capacity
        self.load = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)

        # visited mask; column problem_size+1 is the regret dummy node, blocked initially
        self.visited_ninf_flag = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.visited_ninf_flag[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+2)
        # feasibility mask
        self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.ninf_mask[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+2)
        # completion flags
        self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        # initialize the regret-related state
        self.regret_count = torch.zeros((self.batch_size, self.pomo_size))
        self.mode = torch.full((self.batch_size, self.pomo_size), 0)
        self.last_current_node = None
        self.last_load = None
        self.time_step = 0

        reward = None
        done = False
        return self.reset_state, reward, done

    def pre_step(self):
        # reset selected_count
        self.step_state.selected_count = 0
        # copy the current load
        self.step_state.load = self.load
        # set the current node
        self.step_state.current_node = self.current_node
        # copy the feasibility mask
        self.step_state.ninf_mask = self.ninf_mask

        # return the step state, reward, and done flag
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, selected):
        # selected.shape: (batch, pomo)

        # time-step control
        if self.time_step < 4:
            # the first few steps follow the plain POMO transition
            self.time_step = self.time_step + 1
            self.selected_count = self.selected_count + 1

            # are we at the depot?
            self.at_the_depot = (selected == 0)

            # bookkeeping at specific time steps
            if self.time_step == 3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                # unlock the regret dummy node once a regret move becomes meaningful
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot) & (self.last_current_node != 0)] = 0

            # update the current node and the list of selected nodes
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

            # update demands and load
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            # mark visited nodes (prevents re-selecting them)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

            # rebuild the feasibility mask (blocks nodes whose demand exceeds the remaining load)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2 = torch.full((demand_too_large.shape[0], demand_too_large.shape[1], 1), False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')

            # sync the updated state into self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask

        # from time step 4 onwards, the regret machinery kicks in
        else:
            # classify the action of each rollout by its mode
            action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
            action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
            action2_bool_index = self.mode == 1
            action3_bool_index = self.mode == 2
            action1_index = torch.nonzero(action1_bool_index)
            action2_index = torch.nonzero(action2_bool_index)
            action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

            # update the selection counter
            self.selected_count = self.selected_count + 1
            # a regret action undoes the previous selection
            self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

            # node updates
            self.last_is_depot = (self.last_current_node == 0)
            _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

            # append to the list of selected nodes
            self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

            # update the load
            self.at_the_depot = (selected == 0)
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            # shape: (batch, pomo, problem+1)
            _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
            # extend demand_list with a zero-demand column for the regret dummy node
            demand_list = torch.cat((demand_list, _3), dim=2)
            gathering_index = selected[:, :, None]
            # shape: (batch, pomo, 1)
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
            self.last_load = self.load.clone()
            # shape: (batch, pomo)
            self.load -= selected_demand
            self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            # update the visited mask
            self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
            self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

            # rebuild the feasibility mask
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            # shape: (batch, pomo, problem+2)
            self.ninf_mask[demand_too_large] = float('-inf')

            # update completion: a rollout is finished once every real node is masked as visited
            newly_finished = (self.visited_ninf_flag == float('-inf'))[:, :, :self.problem_size+1].all(dim=2)
            # shape: (batch, pomo)
            self.finished = self.finished + newly_finished
            # shape: (batch, pomo)

            # update the modes
            self.mode[action1_bool_index] = 1
            self.mode[action2_bool_index] = 2
            self.mode[action3_bool_index] = 0
            self.mode[self.finished] = 4

            # once finished, only the depot stays selectable
            self.ninf_mask[:, :, 0][self.finished] = 0
            self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

            # sync the step state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask

        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None
        return self.step_state, reward, done

    def _get_travel_distance(self):
        # m1 marks the steps at which the regret dummy node was selected
        m1 = (self.selected_node_list == self.problem_size+1)
        # m2 marks regret steps and the step just before each of them
        m2 = (m1.roll(dims=2, shifts=-1) | m1)
        # m3 marks the step just after each regret step
        m3 = m1.roll(dims=2, shifts=1)
        # m4 marks ordinary steps, untouched by any regret
        m4 = ~(m2 | m3)

        # each step paired with its predecessor; after a regret, with the node 3 steps back
        selected_node_list_right = self.selected_node_list.roll(dims=2, shifts=1)
        selected_node_list_right2 = self.selected_node_list.roll(dims=2, shifts=3)

        self.regret_mask_matrix = m1
        self.add_mask_matrix = (~m2)

        travel_distances = torch.zeros((self.batch_size, self.pomo_size))
        for t in range(self.selected_node_list.shape[2]):
            add1_index = (m4[:, :, t].unsqueeze(2)).nonzero()
            add3_index = (m3[:, :, t].unsqueeze(2)).nonzero()
            # ordinary edge: node chosen at step t back to its predecessor
            travel_distances[add1_index[:, 0], add1_index[:, 1]] = (
                travel_distances[add1_index[:, 0], add1_index[:, 1]].clone()
                + ((self.depot_node_xy[add1_index[:, 0], self.selected_node_list[add1_index[:, 0], add1_index[:, 1], t], :]
                    - self.depot_node_xy[add1_index[:, 0], selected_node_list_right[add1_index[:, 0], add1_index[:, 1], t], :]) ** 2).sum(1).sqrt()
            )
            # edge right after a regret: connect to the node from 3 steps back instead
            travel_distances[add3_index[:, 0], add3_index[:, 1]] = (
                travel_distances[add3_index[:, 0], add3_index[:, 1]].clone()
                + ((self.depot_node_xy[add3_index[:, 0], self.selected_node_list[add3_index[:, 0], add3_index[:, 1], t], :]
                    - self.depot_node_xy[add3_index[:, 0], selected_node_list_right2[add3_index[:, 0], add3_index[:, 1], t], :]) ** 2).sum(1).sqrt()
            )
        return travel_distances

