Paper: Deep Reinforcement Learning Guided Improvement Heuristic for Job Shop Scheduling
Authors: Cong Zhang, Zhiguang Cao, Wen Song, Yaoxin Wu, Jie Zhang
Venue: ICLR 2024
Paper link: https://doi.org/10.48550/arXiv.2211.10936
Github: https://github.com/zcaicaros/L2S
Summary:
The paper proposes a new DRL-based improvement heuristic for solving the JSSP, in which a GNN representation is used to encode complete solutions. It designs a graph-representation scheme consisting of two modules to effectively capture the dynamic topology and the different node types of the graphs encountered during the improvement process. To speed up solution evaluation during improvement, a message-passing mechanism is proposed that can evaluate multiple solutions simultaneously. The paper proves that the computational complexity of the method grows linearly with problem size. Experiments on classic benchmarks show that the learned improvement policy outperforms state-of-the-art DRL-based methods by a large margin.
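To make the "complete solution as a graph" idea concrete, here is a minimal sketch (not the paper's implementation) of encoding one JSSP solution as a disjunctive graph in PyTorch Geometric. The toy instance, the feature values, and the split into edge_index_pc / edge_index_mc (precedence vs. machine edges, matching the naming used in the source code further down) are illustrative assumptions.

import torch
from torch_geometric.data import Data

# Toy 2-job x 2-machine instance: operations 0..3, job 0 = (op0 -> op1), job 1 = (op2 -> op3).
# Node features per operation, as in the paper's state: [earliest start, latest start, processing time].
# In this tiny example every operation lies on the critical path, so earliest == latest start.
x = torch.tensor([[0., 0., 3.],
                  [3., 3., 2.],
                  [3., 3., 2.],
                  [5., 5., 4.]])

# Precedence (conjunctive) edges, fixed by the job routings: 0 -> 1 and 2 -> 3.
edge_index_pc = torch.tensor([[0, 2],
                              [1, 3]])

# Machine (disjunctive) edges, whose directions encode the current solution
# (here op0 before op2 on machine 1, op1 before op3 on machine 2): 0 -> 2 and 1 -> 3.
edge_index_mc = torch.tensor([[0, 1],
                              [2, 3]])

solution_graph = Data(x=x, edge_index=torch.cat([edge_index_pc, edge_index_mc], dim=-1))
print(solution_graph)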
This is the first paper I have seen that applies DRL-based heuristic improvement to shop scheduling problems, but similar methods are already common in combinatorial optimization, especially for the TSP and its extensions. A few related papers and their open-source code are listed below:
1. Learning 2-opt Heuristics for the Traveling Salesman Problem via Deep Reinforcement Learning
Authors: Paulo R. de O. da Costa, Jason Rhuggenaath, Yingqian Zhang, Alp Akcay
Paper link: https://arxiv.org/abs/2004.01608
GitHub: https://github.com/paulorocosta/learning-2opt-drl
2. Learning Improvement Heuristics for Solving Routing Problems
Authors: Yaoxin Wu, Wen Song, Zhiguang Cao, Jie Zhang, Andrew Lim
Paper link: https://arxiv.org/abs/1912.05784
GitHub: https://github.com/yining043/TSP-improve
Overview:
The overall framework of the algorithm is shown in the figure below:
First, an initial solution is generated with a dispatching rule. During the iterative improvement of this solution, each solution is represented as a disjunctive graph. Unlike conventional improvement heuristics, which must evaluate all neighborhood solutions at every step, the method directly outputs a single operation pair, updates the current solution according to the N5 neighborhood, and repeats this process until a termination condition is reached; a minimal sketch of this loop is given below.
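The following sketch only illustrates the control flow described above; the helpers dispatching_rule_init, n5_neighborhood, swap, and makespan are hypothetical names, not the repository's API.

def improve(instance, policy, max_steps=500):
    # Hedged sketch of the improvement loop; all helper functions below are hypothetical.
    solution = dispatching_rule_init(instance)            # initial solution from a dispatching rule
    incumbent, incumbent_cost = solution, makespan(solution)
    for _ in range(max_steps):                            # termination condition: a fixed step budget
        candidate_pairs = n5_neighborhood(solution)       # feasible operation pairs (N5 moves)
        if not candidate_pairs:
            break
        pair = policy.select(solution, candidate_pairs)   # learned policy picks ONE pair, no full neighborhood evaluation
        solution = swap(solution, pair)                   # apply the move: next state
        if makespan(solution) < incumbent_cost:           # keep the best solution found so far
            incumbent, incumbent_cost = solution, makespan(solution)
    return incumbent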
MDP formulation:
- State: the state at each step is the disjunctive graph of the current solution. Each operation node carries a feature vector containing the operation's earliest and latest start times (for completed operations these equal the actual start time; an operation on the critical path has equal earliest and latest start times) and its processing time;
- Action: an action is a candidate operation pair from the N5 neighborhood;
- State transition (illustrated in the figure below): swapping the two operations specified by the action updates the graph, moving the state from s_t to s_{t+1};
- Reward function: the ultimate goal is to improve the initial solution as much as possible, so the reward is positive only when the new solution is better than the incumbent (the best solution found so far) and zero otherwise; a minimal sketch follows this list.
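A minimal sketch of this incumbent-based reward, assuming makespan is the objective (the function and variable names are illustrative, not the repository's API):

def reward(incumbent_makespan, new_makespan):
    # Positive only when the new solution improves on the best makespan found so far, zero otherwise.
    # Summed over an episode, the rewards equal the total improvement over the initial solution.
    return max(incumbent_makespan - new_makespan, 0)

# Example: incumbent 112, new solution 105 -> reward 7; a worse new solution (120) -> reward 0.
assert reward(112, 105) == 7
assert reward(112, 120) == 0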
Network architecture:
To learn graph embeddings that effectively exploit the topology of the disjunctive graph, a novel dual graph attention network is introduced, which uses two independent modules to embed the forward view and the backward view of the disjunctive graph, respectively. For each view, message propagation follows that view's topology, and aggregation is done with an attention mechanism.
The forward view and the corresponding backward view are shown in the figure below:
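Since the figure is not reproduced here, a minimal illustration of the relation between the two views, assuming the disjunctive graph is stored as a PyG-style edge_index tensor as in the code below: the backward view keeps the same nodes and simply reverses the direction of every arc.

import torch

# Forward view: arcs point from an operation to its job/machine successors.
edge_index_forward = torch.tensor([[0, 1, 0, 2],
                                   [1, 3, 2, 3]])

# Backward view: every arc reversed (successor -> predecessor), so messages
# propagate from the end of the schedule towards the start.
edge_index_backward = torch.flip(edge_index_forward, dims=[0])
print(edge_index_backward)
# tensor([[1, 3, 2, 3],
#         [0, 1, 0, 2]])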
The source code corresponding to the network architecture is as follows:
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import torch.nn as nn
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv, GATConv, global_mean_pool
from torch_geometric.utils import add_self_loops


# One dual-graph attention (DGHAN) layer: two GATConv modules attend over the
# precedence-constraint (pc) edges and the machine-constraint (mc) edges separately,
# and the two resulting node embeddings are averaged.
class DGHANlayer(torch.nn.Module):
    def __init__(self, in_chnl, out_chnl, dropout, concat, heads=2):
        super(DGHANlayer, self).__init__()
        self.dropout = dropout
        self.opsgrp_conv = GATConv(in_chnl, out_chnl, heads=heads, dropout=dropout, concat=concat)
        self.mchgrp_conv = GATConv(in_chnl, out_chnl, heads=heads, dropout=dropout, concat=concat)

    def forward(self, node_h, edge_index_pc, edge_index_mc):
        node_h_pc = F.elu(self.opsgrp_conv(F.dropout(node_h, p=self.dropout, training=self.training), edge_index_pc))
        node_h_mc = F.elu(self.mchgrp_conv(F.dropout(node_h, p=self.dropout, training=self.training), edge_index_mc))
        node_h = torch.mean(torch.stack([node_h_pc, node_h_mc]), dim=0, keepdim=False)
        return node_h


class DGHAN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, dropout, layer_dghan=4, heads=2):
        super(DGHAN, self).__init__()
        self.layer_dghan = layer_dghan
        self.hidden_dim = hidden_dim

        ## DGHAN conv layers
        self.DGHAN_layers = torch.nn.ModuleList()

        # init DGHAN layer
        if layer_dghan == 1:
            # only DGHAN layer
            self.DGHAN_layers.append(DGHANlayer(in_dim, hidden_dim, dropout, concat=False, heads=heads))
        else:
            # first DGHAN layer
            self.DGHAN_layers.append(DGHANlayer(in_dim, hidden_dim, dropout, concat=True, heads=heads))
            # following DGHAN layers
            for layer in range(layer_dghan - 2):
                self.DGHAN_layers.append(DGHANlayer(heads * hidden_dim, hidden_dim, dropout, concat=True, heads=heads))
            # last DGHAN layer
            self.DGHAN_layers.append(DGHANlayer(heads * hidden_dim, hidden_dim, dropout, concat=False, heads=1))

    def forward(self, x, edge_index_pc, edge_index_mc, batch_size):
        # initial layer forward
        h_node = self.DGHAN_layers[0](x, edge_index_pc, edge_index_mc)
        for layer in range(1, self.layer_dghan):
            h_node = self.DGHAN_layers[layer](h_node, edge_index_pc, edge_index_mc)
        # node embeddings plus mean-pooled graph embedding
        return h_node, torch.mean(h_node.reshape(batch_size, -1, self.hidden_dim), dim=1)


# Standard GIN encoder applied to the merged disjunctive graph (pc and mc edges concatenated).
class GIN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, layer_gin=4):
        super(GIN, self).__init__()
        self.layer_gin = layer_gin

        ## GIN conv layers
        self.GIN_layers = torch.nn.ModuleList()

        # init gin layer
        self.GIN_layers.append(
            GINConv(
                Sequential(Linear(in_dim, hidden_dim),
                           torch.nn.BatchNorm1d(hidden_dim),
                           ReLU(),
                           Linear(hidden_dim, hidden_dim)),
                eps=0,
                train_eps=False,
                aggr='mean',
                flow="source_to_target"))
        # rest gin layers
        for layer in range(layer_gin - 1):
            self.GIN_layers.append(
                GINConv(
                    Sequential(Linear(hidden_dim, hidden_dim),
                               torch.nn.BatchNorm1d(hidden_dim),
                               ReLU(),
                               Linear(hidden_dim, hidden_dim)),
                    eps=0,
                    train_eps=False,
                    aggr='mean',
                    flow="source_to_target"))

    def forward(self, x, edge_index, batch):
        hidden_rep = []
        node_pool_over_layer = 0
        # initial layer forward
        h = self.GIN_layers[0](x, edge_index)
        node_pool_over_layer += h
        hidden_rep.append(h)
        # rest layers forward
        for layer in range(1, self.layer_gin):
            h = self.GIN_layers[layer](h, edge_index)
            node_pool_over_layer += h
            hidden_rep.append(h)
        # Graph pool
        gPool_over_layer = 0
        for layer, layer_h in enumerate(hidden_rep):
            g_pool = global_mean_pool(layer_h, batch)
            gPool_over_layer += g_pool
        return node_pool_over_layer, gPool_over_layer


class Actor(nn.Module):
    def __init__(self,
                 in_dim,
                 hidden_dim,
                 embedding_l=4,
                 policy_l=3,
                 embedding_type='gin',
                 heads=4,
                 dropout=0.6):
        super(Actor, self).__init__()
        self.embedding_l = embedding_l
        self.policy_l = policy_l
        self.embedding_type = embedding_type
        if self.embedding_type == 'gin':
            self.embedding = GIN(in_dim=in_dim, hidden_dim=hidden_dim, layer_gin=embedding_l)
        elif self.embedding_type == 'dghan':
            self.embedding = DGHAN(in_dim=in_dim, hidden_dim=hidden_dim, dropout=dropout, layer_dghan=embedding_l,
                                   heads=heads)
        elif self.embedding_type == 'gin+dghan':
            self.embedding_gin = GIN(in_dim=in_dim, hidden_dim=hidden_dim, layer_gin=embedding_l)
            self.embedding_dghan = DGHAN(in_dim=in_dim, hidden_dim=hidden_dim, dropout=dropout,
                                         layer_dghan=embedding_l, heads=heads)
        else:
            raise Exception('embedding type should be either "gin", "dghan", or "gin+dghan".')

        # policy
        self.policy = torch.nn.ModuleList()
        if policy_l == 1:
            if self.embedding_type == 'gin+dghan':
                self.policy.append(Sequential(Linear(hidden_dim * 4, hidden_dim),
                                              # torch.nn.BatchNorm1d(hidden_dim),
                                              torch.nn.Tanh(),
                                              Linear(hidden_dim, hidden_dim)))
            else:
                self.policy.append(Sequential(Linear(hidden_dim * 2, hidden_dim),
                                              # torch.nn.BatchNorm1d(hidden_dim),
                                              torch.nn.Tanh(),
                                              Linear(hidden_dim, hidden_dim)))
        else:
            for layer in range(policy_l):
                if layer == 0:
                    if self.embedding_type == 'gin+dghan':
                        self.policy.append(Sequential(Linear(hidden_dim * 4, hidden_dim),
                                                      # torch.nn.BatchNorm1d(hidden_dim),
                                                      torch.nn.Tanh(),
                                                      Linear(hidden_dim, hidden_dim)))
                    else:
                        self.policy.append(Sequential(Linear(hidden_dim * 2, hidden_dim),
                                                      # torch.nn.BatchNorm1d(hidden_dim),
                                                      torch.nn.Tanh(),
                                                      Linear(hidden_dim, hidden_dim)))
                else:
                    self.policy.append(Sequential(Linear(hidden_dim, hidden_dim),
                                                  # torch.nn.BatchNorm1d(hidden_dim),
                                                  torch.nn.Tanh(),
                                                  Linear(hidden_dim, hidden_dim)))

    def forward(self, batch_states, feasible_actions):
        if self.embedding_type == 'gin':
            node_embed, graph_embed = self.embedding(batch_states.x,
                                                     add_self_loops(torch.cat([batch_states.edge_index_pc,
                                                                               batch_states.edge_index_mc],
                                                                              dim=-1))[0],
                                                     batch_states.batch)
        elif self.embedding_type == 'dghan':
            node_embed, graph_embed = self.embedding(batch_states.x,
                                                     add_self_loops(batch_states.edge_index_pc)[0],
                                                     add_self_loops(batch_states.edge_index_mc)[0],
                                                     len(feasible_actions))
        elif self.embedding_type == 'gin+dghan':
            node_embed_gin, graph_embed_gin = self.embedding_gin(batch_states.x,
                                                                 add_self_loops(torch.cat([batch_states.edge_index_pc,
                                                                                           batch_states.edge_index_mc],
                                                                                          dim=-1))[0],
                                                                 batch_states.batch)
            node_embed_dghan, graph_embed_dghan = self.embedding_dghan(batch_states.x,
                                                                       add_self_loops(batch_states.edge_index_pc)[0],
                                                                       add_self_loops(batch_states.edge_index_mc)[0],
                                                                       len(feasible_actions))
            node_embed = torch.cat([node_embed_gin, node_embed_dghan], dim=-1)
            graph_embed = torch.cat([graph_embed_gin, graph_embed_dghan], dim=-1)
        else:
            raise Exception('embedding type should be either "gin", "dghan", or "gin+dghan".')

        device = node_embed.device
        batch_size = graph_embed.shape[0]
        n_nodes_per_state = node_embed.shape[0] // batch_size

        # augment node embedding with graph embedding then forwarding policy
        node_embed_augmented = torch.cat(
            [node_embed, graph_embed.repeat_interleave(repeats=n_nodes_per_state, dim=0)],
            dim=-1).reshape(batch_size, n_nodes_per_state, -1)
        for layer in range(self.policy_l):
            node_embed_augmented = self.policy[layer](node_embed_augmented)

        # action score: pairwise scores between node embeddings, one score per (operation, operation) pair
        action_score = torch.bmm(node_embed_augmented, node_embed_augmented.transpose(-1, -2))

        # prepare mask
        carries = np.arange(0, batch_size * n_nodes_per_state, n_nodes_per_state)
        a_merge = []  # merge index of actions of all states
        action_count = []  # list of #actions for each state
        for i in range(len(feasible_actions)):
            action_count.append(len(feasible_actions[i]))
            for j in range(len(feasible_actions[i])):
                a_merge.append([feasible_actions[i][j][0] + carries[i], feasible_actions[i][j][1]])
        a_merge = np.array(a_merge)
        mask = torch.ones(size=[batch_size * n_nodes_per_state, n_nodes_per_state], dtype=torch.bool, device=device)
        mask[a_merge[:, 0], a_merge[:, 1]] = False
        mask.resize_as_(action_score)

        # pi
        action_score.masked_fill_(mask, -np.inf)
        action_score_flat = action_score.reshape(batch_size, 1, -1)
        pi = F.softmax(action_score_flat, dim=-1)
        dist = Categorical(probs=pi)
        actions_id = dist.sample()
        # actions_id = torch.argmax(pi, dim=-1)  # greedy action
        sampled_actions = [[actions_id[i].item() // n_nodes_per_state, actions_id[i].item() % n_nodes_per_state]
                           for i in range(len(feasible_actions))]
        log_prob = dist.log_prob(actions_id)  # log_prob using Pytorch API, this will have a gradient shift, reference: https://github.com/pytorch/pytorch/issues/61727. Used in paper submission version.
        # log_prob = torch.log(torch.gather(pi, -1, actions_id.unsqueeze(-1)) + -1e-7).squeeze(-1)  # log_prob calculated manually, this will not have a gradient shift. Switch to this after paper submission.

        return sampled_actions, log_prob


if __name__ == '__main__':
    import random
    from env.environment import JsspN5, BatchGraph
    from env.generateJSP import uni_instance_gen

    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    n_j = 10
    n_m = 10
    l = 1
    h = 99
    reward_type = 'yaoxin'
    init_type = 'fdd-divide-mwkr'
    b_size = 2
    transit = 1
    hid_dim = 128
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    np.random.seed(1)
    random.seed(1)

    env = JsspN5(n_job=n_j, n_mch=n_m, low=l, high=h, reward_type=reward_type)
    batch_data = BatchGraph()
    # instances = np.load('../test_data/syn{}x{}.npy'.format(n_j, n_m))
    instances = np.load('../validation_data/validation_instance_{}x{}[{},{}].npy'.format(n_j, n_m, l, h))
    # instances = np.array([uni_instance_gen(n_j=n_j, n_m=n_m, low=l, high=h) for _ in range(b_size)])
    states, feasible_as, dones = env.reset(instances=instances, init_type=init_type, device=dev)
    # print(env.incumbent_objs)
    # print(feasible_as)
    actor = Actor(3, hid_dim, embedding_l=4, policy_l=4, embedding_type='gin+dghan', heads=1, dropout=0.0).to(dev)

    while env.itr < 500:
        batch_data.wrapper(*states)
        actions, log_ps = actor(batch_data, feasible_as)
        states, rewards, feasible_as, dones = env.step(actions, dev)
        # print(actions)
        # print(env.incumbent_objs)
        # print(feasible_as)
        # grad = torch.autograd.grad(log_ps.sum(), [param for param in actor.parameters()])

    # print(env.incumbent_objs)
    # print(np.load('./validation_data/validation{}x{}_ortools_result.npy'.format(n_j, n_m)))
    optimal = np.load('../validation_data/validation{}x{}_ortools_result.npy'.format(n_j, n_m))
    gap = ((env.current_objs.view(-1).cpu().numpy() - optimal) / optimal).mean()
    print(gap)
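One usage note: as the commented-out line in Actor.forward suggests, the sampling step dist.sample() can be replaced at test time by greedy decoding, e.g. actions_id = torch.argmax(pi, dim=-1), which deterministically picks the highest-probability operation pair in each state instead of drawing from the Categorical distribution.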