ICLR Paper Study | A DRL-Based Improvement Heuristic for Solving the JSSP

news/2025/1/23 15:05:56/

Paper title: Deep Reinforcement Learning Guided Improvement Heuristic for Job Shop Scheduling

Authors: Cong Zhang, Zhiguang Cao, Wen Song, Yaoxin Wu, Jie Zh…

Published at: ICLR 2024

Paper link: https://doi.org/10.48550/arXiv.2211.10936

Github: https://github.com/zcaicaros/L2S

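Paper summary:

This paper proposes a new DRL-based improvement heuristic for solving the JSSP, in which a GNN representation encodes complete solutions. It designs a GNN-based representation scheme made up of two modules that capture the dynamic topology and the different node types of the graphs encountered during the improvement process. To speed up solution evaluation during improvement, it proposes a message-passing mechanism that can evaluate multiple solutions simultaneously. The paper proves that the computational complexity of the method is linear in the problem size. Experiments on classic benchmarks show that the learned improvement policies outperform state-of-the-art DRL-based methods by a large margin.

        This is the first paper I have seen that uses DRL for heuristic improvement on shop scheduling problems, although similar methods are already common in combinatorial optimization, especially for the TSP and its extensions. I collected a few related papers and their open-source code below: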

1. Learning 2-opt Heuristics for the Traveling Salesman Problem via Deep Reinforcement Learning

Authors: Paulo R. de O. da Costa, Jason Rhuggenaath, Yingqian Zhang, Alp Akcay

Paper link: [2004.01608] Learning 2-opt Heuristics for the Traveling Salesman Problem via Deep Reinforcement Learning

Github: https://github.com/paulorocosta/learning-2opt-drl

2. Learning Improvement Heuristics for Solving Routing Problems

Authors: Yaoxin Wu, Wen Song, Zhiguang Cao, Jie Zhang, Andrew Lim

Paper link: [1912.05784v2] Learning Improvement Heuristics for Solving Routing Problems

Github: GitHub - yining043/TSP-improve: An improvement-based Deep Reinforcement Learning Algorithm presented in paper https://arxiv.org/abs/1912.05784v2 for solving the TSP problem.

 

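Overview:

The overall framework of the algorithm is shown in the figure below:

        First, an initial solution is generated with a dispatching rule. While the initial solution is being improved iteratively, every solution is represented as a disjunctive graph. Traditional heuristics must evaluate all neighboring solutions at each step; the method in the paper instead outputs an operation pair directly and updates the current solution according to the N5 neighborhood, repeating this process until a termination condition is met.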
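The loop described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the paper's implementation: the names `improvement_search`, `candidate_moves`, `choose_move` and `apply_move` are hypothetical, a toy problem (ordering a sequence with adjacent swaps, cost = number of inversions) stands in for the JSSP with the N5 neighborhood, and a seeded random choice stands in for the learned policy.

```python
import random
from typing import Callable, List, Tuple

Move = Tuple[int, int]


def improvement_search(
    initial: List[int],
    candidate_moves: Callable[[List[int]], List[Move]],
    choose_move: Callable[[List[int], List[Move]], Move],
    apply_move: Callable[[List[int], Move], List[int]],
    cost: Callable[[List[int]], int],
    max_steps: int,
) -> Tuple[List[int], int]:
    """Generic improvement loop: pick ONE move per step, apply it, track the best."""
    current = list(initial)
    best, best_cost = list(current), cost(current)
    for _ in range(max_steps):
        moves = candidate_moves(current)
        if not moves:
            break
        # The policy selects a move directly; neighbors are not all evaluated.
        move = choose_move(current, moves)
        current = apply_move(current, move)
        current_cost = cost(current)
        if current_cost < best_cost:
            best, best_cost = list(current), current_cost
    return best, best_cost


# --- Toy stand-ins (hypothetical, NOT the JSSP or the N5 neighborhood) ---
def count_inversions(seq: List[int]) -> int:
    return sum(
        1
        for i in range(len(seq))
        for j in range(i + 1, len(seq))
        if seq[i] > seq[j]
    )


def adjacent_pairs(seq: List[int]) -> List[Move]:
    return [(i, i + 1) for i in range(len(seq) - 1)]


def swap(seq: List[int], move: Move) -> List[int]:
    i, j = move
    out = list(seq)
    out[i], out[j] = out[j], out[i]
    return out


rng = random.Random(0)  # seeded random choice stands in for the learned policy
start = [4, 2, 5, 1, 3]
best, best_cost = improvement_search(
    start, adjacent_pairs, lambda s, m: rng.choice(m), swap, count_inversions, 200
)
print(best, best_cost)
```

Because the best solution is tracked separately from the current one, the search may pass through worse solutions without losing the best one found.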

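MDP model:

  1. State: the state at each step is the corresponding disjunctive graph. The feature vector of each operation node contains the operation's earliest and latest start times (for completed operations, the actual start time; for operations on the critical path, the earliest and latest start times are equal) and its processing time;
  2. Action: an action is a candidate operation pair from the N5 neighborhood;
  3. State transition (shown in the figure below): swapping the operation pair given by the action moves the state from s_t to s_{t+1};
  4. Reward function: the ultimate goal of the paper is to improve the initial solution as much as possible, so the reward is > 0 only when the new solution is better than the current solution.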
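The node features and the reward in the list above can be made concrete on a tiny instance. The sketch below is an illustration under assumptions, not code from the paper's repository: `start_times` and `improvement_reward` are hypothetical names, the 2-job, 2-machine instance is invented, and the reward is assumed to be the makespan improvement over a reference solution, clipped at zero.

```python
from collections import deque
from typing import List, Tuple


def start_times(proc: List[int], arcs: List[Tuple[int, int]]):
    """Earliest/latest start times and makespan of an acyclic disjunctive graph.

    proc[i] is the processing time of operation i; an arc (u, v) means
    "u must finish before v starts" (job precedence or machine order).
    """
    n = len(proc)
    succ = [[] for _ in range(n)]
    indeg = [0] * n
    for u, v in arcs:
        succ[u].append(v)
        indeg[v] += 1

    # Topological order (Kahn's algorithm).
    order = []
    queue = deque(i for i in range(n) if indeg[i] == 0)
    while queue:
        u = queue.popleft()
        order.append(u)
        for v in succ[u]:
            indeg[v] -= 1
            if indeg[v] == 0:
                queue.append(v)
    if len(order) != n:
        raise ValueError("graph has a cycle, so the solution is infeasible")

    # Forward pass: earliest start = longest path from any source.
    est = [0] * n
    for u in order:
        for v in succ[u]:
            est[v] = max(est[v], est[u] + proc[u])
    makespan = max(est[i] + proc[i] for i in range(n))

    # Backward pass: latest start that does not delay the makespan.
    lst = [makespan - proc[i] for i in range(n)]
    for u in reversed(order):
        for v in succ[u]:
            lst[u] = min(lst[u], lst[v] - proc[u])
    return est, lst, makespan


def improvement_reward(reference_makespan: int, new_makespan: int) -> int:
    """Assumed reward shape: positive only when the new solution is better."""
    return max(reference_makespan - new_makespan, 0)


# Job 0: op0 (machine 0, 3) -> op1 (machine 1, 2)
# Job 1: op2 (machine 1, 2) -> op3 (machine 0, 4)
proc = [3, 2, 2, 4]
job_arcs = [(0, 1), (2, 3)]
machine_arcs = [(0, 3), (2, 1)]  # machine 0: op0 then op3; machine 1: op2 then op1
est, lst, makespan = start_times(proc, job_arcs + machine_arcs)
critical = [i for i in range(len(proc)) if est[i] == lst[i]]
print(est, lst, makespan, critical)  # [0, 3, 0, 3] [0, 5, 1, 3] 7 [0, 3]

# Swapping the critical pair (op0, op3) on machine 0 gives a worse schedule here.
_, _, new_makespan = start_times(proc, job_arcs + [(3, 0), (2, 1)])
print(new_makespan, improvement_reward(makespan, new_makespan))  # 11 0
```

Operations 0 and 3 have equal earliest and latest start times, so they form the critical path, which is where N5 looks for candidate swaps.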

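Network architecture:

        To learn graph embeddings effectively from the topological features of the disjunctive graph, the paper introduces a new bidirectional graph attention network, which uses two independent modules to embed the forward view and the backward view of the disjunctive graph, respectively. In each view, messages propagate along that view's own topology, and aggregation is done through an attention mechanism.

The conversion between the forward view and the backward view is shown in the figure below:

The source code for the network architecture is as follows: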
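As a rough illustration of the two views, the sketch below assumes that the backward view is obtained by reversing every arc of the forward view, so that a node receives messages from its predecessors in one view and from its successors in the other. The helper names `backward_view` and `message_sources` and the four-node graph are hypothetical.

```python
from typing import Dict, List, Tuple

Arc = Tuple[int, int]  # (source, target): messages flow source -> target


def backward_view(arcs: List[Arc]) -> List[Arc]:
    """Reverse every arc (assumed definition of the backward view)."""
    return [(v, u) for (u, v) in arcs]


def message_sources(num_nodes: int, arcs: List[Arc]) -> Dict[int, List[int]]:
    """For each node, list the nodes it aggregates messages from."""
    sources: Dict[int, List[int]] = {v: [] for v in range(num_nodes)}
    for u, v in arcs:
        sources[v].append(u)
    return sources


# Tiny disjunctive graph: job arcs (0->1, 2->3) plus machine arcs (0->3, 2->1).
forward_arcs = [(0, 1), (2, 3), (0, 3), (2, 1)]
fwd = message_sources(4, forward_arcs)
bwd = message_sources(4, backward_view(forward_arcs))

print(fwd)  # node 1 hears from its predecessors 0 and 2
print(bwd)  # node 0 hears from its successors 1 and 3
```

Embedding both views lets a node summarize what comes before it and what comes after it in the schedule.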

import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import numpy as np
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import torch.nn as nn
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv, GATConv, global_mean_pool
from torch_geometric.utils import add_self_loops


class DGHANlayer(torch.nn.Module):
    def __init__(self, in_chnl, out_chnl, dropout, concat, heads=2):
        super(DGHANlayer, self).__init__()
        self.dropout = dropout
        self.opsgrp_conv = GATConv(in_chnl, out_chnl, heads=heads, dropout=dropout, concat=concat)
        self.mchgrp_conv = GATConv(in_chnl, out_chnl, heads=heads, dropout=dropout, concat=concat)

    def forward(self, node_h, edge_index_pc, edge_index_mc):
        node_h_pc = F.elu(self.opsgrp_conv(F.dropout(node_h, p=self.dropout, training=self.training), edge_index_pc))
        node_h_mc = F.elu(self.mchgrp_conv(F.dropout(node_h, p=self.dropout, training=self.training), edge_index_mc))
        node_h = torch.mean(torch.stack([node_h_pc, node_h_mc]), dim=0, keepdim=False)
        return node_h


class DGHAN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, dropout, layer_dghan=4, heads=2):
        super(DGHAN, self).__init__()
        self.layer_dghan = layer_dghan
        self.hidden_dim = hidden_dim

        ## DGHAN conv layers
        self.DGHAN_layers = torch.nn.ModuleList()

        # init DGHAN layer
        if layer_dghan == 1:
            # only DGHAN layer
            self.DGHAN_layers.append(DGHANlayer(in_dim, hidden_dim, dropout, concat=False, heads=heads))
        else:
            # first DGHAN layer
            self.DGHAN_layers.append(DGHANlayer(in_dim, hidden_dim, dropout, concat=True, heads=heads))
            # following DGHAN layers
            for layer in range(layer_dghan - 2):
                self.DGHAN_layers.append(DGHANlayer(heads * hidden_dim, hidden_dim, dropout, concat=True, heads=heads))
            # last DGHAN layer
            self.DGHAN_layers.append(DGHANlayer(heads * hidden_dim, hidden_dim, dropout, concat=False, heads=1))

    def forward(self, x, edge_index_pc, edge_index_mc, batch_size):
        # initial layer forward
        h_node = self.DGHAN_layers[0](x, edge_index_pc, edge_index_mc)
        for layer in range(1, self.layer_dghan):
            h_node = self.DGHAN_layers[layer](h_node, edge_index_pc, edge_index_mc)
        return h_node, torch.mean(h_node.reshape(batch_size, -1, self.hidden_dim), dim=1)


class GIN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, layer_gin=4):
        super(GIN, self).__init__()
        self.layer_gin = layer_gin

        ## GIN conv layers
        self.GIN_layers = torch.nn.ModuleList()

        # init gin layer
        self.GIN_layers.append(
            GINConv(
                Sequential(Linear(in_dim, hidden_dim),
                           torch.nn.BatchNorm1d(hidden_dim),
                           ReLU(),
                           Linear(hidden_dim, hidden_dim)),
                eps=0,
                train_eps=False,
                aggr='mean',
                flow="source_to_target")
        )

        # rest gin layers
        for layer in range(layer_gin - 1):
            self.GIN_layers.append(
                GINConv(
                    Sequential(Linear(hidden_dim, hidden_dim),
                               torch.nn.BatchNorm1d(hidden_dim),
                               ReLU(),
                               Linear(hidden_dim, hidden_dim)),
                    eps=0,
                    train_eps=False,
                    aggr='mean',
                    flow="source_to_target")
            )

    def forward(self, x, edge_index, batch):
        hidden_rep = []
        node_pool_over_layer = 0
        # initial layer forward
        h = self.GIN_layers[0](x, edge_index)
        node_pool_over_layer += h
        hidden_rep.append(h)
        # rest layers forward
        for layer in range(1, self.layer_gin):
            h = self.GIN_layers[layer](h, edge_index)
            node_pool_over_layer += h
            hidden_rep.append(h)

        # Graph pool
        gPool_over_layer = 0
        for layer, layer_h in enumerate(hidden_rep):
            g_pool = global_mean_pool(layer_h, batch)
            gPool_over_layer += g_pool

        return node_pool_over_layer, gPool_over_layer


class Actor(nn.Module):
    def __init__(self,
                 in_dim,
                 hidden_dim,
                 embedding_l=4,
                 policy_l=3,
                 embedding_type='gin',
                 heads=4,
                 dropout=0.6):
        super(Actor, self).__init__()
        self.embedding_l = embedding_l
        self.policy_l = policy_l
        self.embedding_type = embedding_type
        if self.embedding_type == 'gin':
            self.embedding = GIN(in_dim=in_dim, hidden_dim=hidden_dim, layer_gin=embedding_l)
        elif self.embedding_type == 'dghan':
            self.embedding = DGHAN(in_dim=in_dim, hidden_dim=hidden_dim, dropout=dropout, layer_dghan=embedding_l, heads=heads)
        elif self.embedding_type == 'gin+dghan':
            self.embedding_gin = GIN(in_dim=in_dim, hidden_dim=hidden_dim, layer_gin=embedding_l)
            self.embedding_dghan = DGHAN(in_dim=in_dim, hidden_dim=hidden_dim, dropout=dropout, layer_dghan=embedding_l, heads=heads)
        else:
            raise Exception('embedding type should be either "gin", "dghan", or "gin+dghan".')

        # policy
        self.policy = torch.nn.ModuleList()
        if policy_l == 1:
            if self.embedding_type == 'gin+dghan':
                self.policy.append(Sequential(Linear(hidden_dim * 4, hidden_dim),
                                              # torch.nn.BatchNorm1d(hidden_dim),
                                              torch.nn.Tanh(),
                                              Linear(hidden_dim, hidden_dim)))
            else:
                self.policy.append(Sequential(Linear(hidden_dim * 2, hidden_dim),
                                              # torch.nn.BatchNorm1d(hidden_dim),
                                              torch.nn.Tanh(),
                                              Linear(hidden_dim, hidden_dim)))
        else:
            for layer in range(policy_l):
                if layer == 0:
                    if self.embedding_type == 'gin+dghan':
                        self.policy.append(Sequential(Linear(hidden_dim * 4, hidden_dim),
                                                      # torch.nn.BatchNorm1d(hidden_dim),
                                                      torch.nn.Tanh(),
                                                      Linear(hidden_dim, hidden_dim)))
                    else:
                        self.policy.append(Sequential(Linear(hidden_dim * 2, hidden_dim),
                                                      # torch.nn.BatchNorm1d(hidden_dim),
                                                      torch.nn.Tanh(),
                                                      Linear(hidden_dim, hidden_dim)))
                else:
                    self.policy.append(Sequential(Linear(hidden_dim, hidden_dim),
                                                  # torch.nn.BatchNorm1d(hidden_dim),
                                                  torch.nn.Tanh(),
                                                  Linear(hidden_dim, hidden_dim)))

    def forward(self, batch_states, feasible_actions):
        if self.embedding_type == 'gin':
            node_embed, graph_embed = self.embedding(
                batch_states.x,
                add_self_loops(torch.cat([batch_states.edge_index_pc, batch_states.edge_index_mc], dim=-1))[0],
                batch_states.batch)
        elif self.embedding_type == 'dghan':
            node_embed, graph_embed = self.embedding(
                batch_states.x,
                add_self_loops(batch_states.edge_index_pc)[0],
                add_self_loops(batch_states.edge_index_mc)[0],
                len(feasible_actions))
        elif self.embedding_type == 'gin+dghan':
            node_embed_gin, graph_embed_gin = self.embedding_gin(
                batch_states.x,
                add_self_loops(torch.cat([batch_states.edge_index_pc, batch_states.edge_index_mc], dim=-1))[0],
                batch_states.batch)
            node_embed_dghan, graph_embed_dghan = self.embedding_dghan(
                batch_states.x,
                add_self_loops(batch_states.edge_index_pc)[0],
                add_self_loops(batch_states.edge_index_mc)[0],
                len(feasible_actions))
            node_embed = torch.cat([node_embed_gin, node_embed_dghan], dim=-1)
            graph_embed = torch.cat([graph_embed_gin, graph_embed_dghan], dim=-1)
        else:
            raise Exception('embedding type should be either "gin", "dghan", or "gin+dghan".')

        device = node_embed.device
        batch_size = graph_embed.shape[0]
        n_nodes_per_state = node_embed.shape[0] // batch_size

        # augment node embedding with graph embedding then forwarding policy
        node_embed_augmented = torch.cat(
            [node_embed, graph_embed.repeat_interleave(repeats=n_nodes_per_state, dim=0)],
            dim=-1).reshape(batch_size, n_nodes_per_state, -1)
        for layer in range(self.policy_l):
            node_embed_augmented = self.policy[layer](node_embed_augmented)

        # action score
        action_score = torch.bmm(node_embed_augmented, node_embed_augmented.transpose(-1, -2))

        # prepare mask
        carries = np.arange(0, batch_size * n_nodes_per_state, n_nodes_per_state)
        a_merge = []  # merge index of actions of all states
        action_count = []  # list of #actions for each state
        for i in range(len(feasible_actions)):
            action_count.append(len(feasible_actions[i]))
            for j in range(len(feasible_actions[i])):
                a_merge.append([feasible_actions[i][j][0] + carries[i], feasible_actions[i][j][1]])
        a_merge = np.array(a_merge)
        mask = torch.ones(size=[batch_size * n_nodes_per_state, n_nodes_per_state], dtype=torch.bool, device=device)
        mask[a_merge[:, 0], a_merge[:, 1]] = False
        mask.resize_as_(action_score)

        # pi
        action_score.masked_fill_(mask, -np.inf)
        action_score_flat = action_score.reshape(batch_size, 1, -1)
        pi = F.softmax(action_score_flat, dim=-1)

        dist = Categorical(probs=pi)
        actions_id = dist.sample()
        # actions_id = torch.argmax(pi, dim=-1)  # greedy action
        sampled_actions = [[actions_id[i].item() // n_nodes_per_state, actions_id[i].item() % n_nodes_per_state] for i in range(len(feasible_actions))]
        log_prob = dist.log_prob(actions_id)  # log_prob using Pytorch API, this will have a gradient shift, reference: https://github.com/pytorch/pytorch/issues/61727. Used in paper submission version.
        # log_prob = torch.log(torch.gather(pi, -1, actions_id.unsqueeze(-1)) + -1e-7).squeeze(-1)  # log_prob calculated manually, this will not have a gradient shift. Switch to this after paper submission.

        return sampled_actions, log_prob


if __name__ == '__main__':
    import random
    from env.environment import JsspN5, BatchGraph
    from env.generateJSP import uni_instance_gen

    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    n_j = 10
    n_m = 10
    l = 1
    h = 99
    reward_type = 'yaoxin'
    init_type = 'fdd-divide-mwkr'
    b_size = 2
    transit = 1
    hid_dim = 128

    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    np.random.seed(1)
    random.seed(1)

    env = JsspN5(n_job=n_j, n_mch=n_m, low=l, high=h, reward_type=reward_type)
    batch_data = BatchGraph()
    # instances = np.load('../test_data/syn{}x{}.npy'.format(n_j, n_m))
    instances = np.load('../validation_data/validation_instance_{}x{}[{},{}].npy'.format(n_j, n_m, l, h))
    # instances = np.array([uni_instance_gen(n_j=n_j, n_m=n_m, low=l, high=h) for _ in range(b_size)])
    states, feasible_as, dones = env.reset(instances=instances, init_type=init_type, device=dev)
    # print(env.incumbent_objs)
    # print(feasible_as)

    actor = Actor(3, hid_dim, embedding_l=4, policy_l=4, embedding_type='gin+dghan', heads=1, dropout=0.0).to(dev)

    while env.itr < 500:
        batch_data.wrapper(*states)
        actions, log_ps = actor(batch_data, feasible_as)
        states, rewards, feasible_as, dones = env.step(actions, dev)
        # print(actions)
        # print(env.incumbent_objs)
        # print(feasible_as)

    # grad = torch.autograd.grad(log_ps.sum(), [param for param in actor.parameters()])
    # print(env.incumbent_objs)
    # print(np.load('./validation_data/validation{}x{}_ortools_result.npy'.format(n_j, n_m)))
    optimal = np.load('../validation_data/validation{}x{}_ortools_result.npy'.format(n_j, n_m))
    gap = ((env.current_objs.view(-1).cpu().numpy() - optimal) / optimal).mean()
    print(gap)
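The last steps of `Actor.forward` above score every ordered node pair, mask the infeasible pairs with negative infinity, apply a softmax over the flattened score matrix, and decode the sampled flat index back into a pair with `//` and `%`. The dependency-free sketch below mirrors that logic for a single state; `masked_pair_policy` and `decode_action` are hypothetical names, and the score matrix is made-up data.

```python
import math
from typing import List, Tuple


def masked_pair_policy(
    scores: List[List[float]], feasible: List[Tuple[int, int]]
) -> List[float]:
    """Softmax over the flattened n x n score matrix, restricted to feasible pairs."""
    if not feasible:
        raise ValueError("at least one feasible action is required")
    n = len(scores)
    allowed = set(feasible)
    # Infeasible pairs get -inf so that they receive zero probability.
    flat = [
        scores[i][j] if (i, j) in allowed else -math.inf
        for i in range(n)
        for j in range(n)
    ]
    peak = max(flat)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in flat]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def decode_action(flat_id: int, n_nodes: int) -> Tuple[int, int]:
    """Map a flat index back to (row, column), as done with // and % above."""
    return divmod(flat_id, n_nodes)


scores = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
feasible = [(0, 1), (1, 2)]
probs = masked_pair_policy(scores, feasible)
greedy_id = max(range(len(probs)), key=probs.__getitem__)
print(decode_action(greedy_id, 3))  # (1, 2)
```

Taking the argmax corresponds to the greedy variant left commented out in the listing; sampling from `probs` corresponds to the `Categorical` draw.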

