面向强化学习的状态空间建模:RSSM的介绍和PyTorch实现

devtools/2025/1/16 13:11:24/

循环状态空间模型(Recurrent State Space Models, RSSM)最初由 Danijar Hafer 等人在论文《Learning Latent Dynamics for Planning from Pixels》中提出。该模型在现代基于模型的强化学习(Model-Based Reinforcement Learning, MBRL)中发挥着关键作用,其主要目标是构建可靠的环境动态预测模型。通过这些学习得到的模型,智能体能够模拟未来轨迹并进行前瞻性的行为规划。

下面我们就来用一个实际案例来介绍RSSM。

环境配置

环境配置是实现过程中的首要步骤。我们这里用易于使用的 Gym API。为了提高实现效率,设计了多个模块化的包装器(wrapper),用于初始化参数并将观察结果调整为指定格式。

InitialWrapper 的设计允许在不执行任何动作的情况下进行特定数量的观察,同时支持在返回观察结果之前多次重复同一动作。这种设计对于响应具有显著延迟特性的环境特别有效。

PreprocessFrame 包装器负责将观察结果转换为正确的数据类型(本文中使用 numpy 数组),并支持灰度转换功能。

 classInitialWrapper(gym.Wrapper):  def__init__(self, env: gym.Env, no_ops: int=0, repeat: int=1):  super(InitialWrapper, self).__init__(env)  self.repeat=repeat  self.no_ops=no_ops  self.op_counter=0  defstep(self, action: ActType) ->Tuple[ObsType, float, bool, bool, dict]:  ifself.op_counter<self.no_ops:  obs, reward, done, info=self.env.step(0)  self.op_counter+=1  total_reward=0.0  done=False  for_inrange(self.repeat):  obs, reward, done, info=self.env.step(action)  total_reward+=reward  ifdone:  break  returnobs, total_reward, done, info  classPreprocessFrame(gym.ObservationWrapper):  def__init__(self, env: gym.Env, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool=False):  super(PreprocessFrame, self).__init__(env)  self.shape=new_shape  self.observation_space=gym.spaces.Box(low=0.0, high=1.0, shape=self.shape, dtype=np.float32)  self.grayscale=grayscale  ifself.grayscale:  self.observation_space=gym.spaces.Box(low=0.0, high=1.0, shape=(*self.shape[:-1], 1), dtype=np.float32)  defobservation(self, obs: torch.Tensor) ->torch.Tensor:  obs=obs.astype(np.uint8)  new_frame=cv.resize(obs, self.shape[:-1], interpolation=cv.INTER_AREA)  ifself.grayscale:  new_frame=cv.cvtColor(new_frame, cv.COLOR_RGB2GRAY)  new_frame=np.expand_dims(new_frame, -1)  torch_frame=torch.from_numpy(new_frame).float()  torch_frame=torch_frame/255.0  returntorch_frame  defmake_env(env_name: str, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool=True, **kwargs):  env=gym.make(env_name, **kwargs)  env=PreprocessFrame(env, new_shape, grayscale=grayscale)  returnenv

make_env 函数用于创建一个具有指定配置参数的环境实例。

模型架构

RSSM 的实现依赖于多个关键模型组件。具体来说,需要实现以下四个核心模块:

  • 原始观察编码器(Encoder)
  • 动态模型(Dynamics Model):通过确定性状态 h 和随机状态 s 对编码观察的时间依赖性进行建模
  • 解码器(Decoder):将随机状态和确定性状态映射回原始观察空间
  • 奖励模型(Reward Model):将随机状态和确定性状态映射到奖励值

RSSM 模型组件结构图。模型包含随机状态 s 和确定性状态 h。

编码器实现

编码器采用简单的卷积神经网络(CNN)结构,将输入图像降维到一维嵌入表示。实现中使用了 BatchNorm 来提升训练稳定性。

 classEncoderCNN(nn.Module):  def__init__(self, in_channels: int, embedding_dim: int=2048, input_shape: Tuple[int, int] = (128, 128)):  super(EncoderCNN, self).__init__()  # 定义卷积层结构self.conv1=nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)  self.conv2=nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  self.conv3=nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  self.conv4=nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)  self.fc1=nn.Linear(self._compute_conv_output((in_channels, input_shape[0], input_shape[1])), embedding_dim)  # 批标准化层self.bn1=nn.BatchNorm2d(32)  self.bn2=nn.BatchNorm2d(64)  self.bn3=nn.BatchNorm2d(128)  self.bn4=nn.BatchNorm2d(256)  def_compute_conv_output(self, shape: Tuple[int, int, int]):  withtorch.no_grad():  x=torch.randn(1, shape[0], shape[1], shape[2])  x=self.conv1(x)  x=self.conv2(x)  x=self.conv3(x)  x=self.conv4(x)  returnx.shape[1] *x.shape[2] *x.shape[3]  defforward(self, x):  x=torch.relu(self.conv1(x))  x=self.bn1(x)  x=torch.relu(self.conv2(x))  x=self.bn2(x)  x=torch.relu(self.conv3(x))  x=self.bn3(x)  x=self.conv4(x)  x=self.bn4(x)  x=x.view(x.size(0), -1)  x=self.fc1(x)  returnx

解码器实现

解码器遵循传统自编码器架构设计,其功能是将编码后的观察结果重建回原始观察空间。

 classDecoderCNN(nn.Module):  def__init__(self, hidden_size: int, state_size: int,  embedding_size: int,  use_bn: bool=True, output_shape: Tuple[int, int] = (3, 128, 128)):  super(DecoderCNN, self).__init__()  self.output_shape=output_shape  self.embedding_size=embedding_size  # 全连接层进行特征变换self.fc1=nn.Linear(hidden_size+state_size, embedding_size)  self.fc2=nn.Linear(embedding_size, 256* (output_shape[1] //16) * (output_shape[2] //16))  # 反卷积层进行上采样self.conv1=nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  self.conv2=nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  self.conv3=nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  self.conv4=nn.ConvTranspose2d(32, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1)  # 批标准化层self.bn1=nn.BatchNorm2d(128)  self.bn2=nn.BatchNorm2d(64)  self.bn3=nn.BatchNorm2d(32)  self.use_bn=use_bn  defforward(self, h: torch.Tensor, s: torch.Tensor):  x=torch.cat([h, s], dim=-1)  x=self.fc1(x)  x=torch.relu(x)  x=self.fc2(x)  x=x.view(-1, 256, self.output_shape[1] //16, self.output_shape[2] //16)  ifself.use_bn:  x=torch.relu(self.bn1(self.conv1(x)))  x=torch.relu(self.bn2(self.conv2(x)))  x=torch.relu(self.bn3(self.conv3(x)))  else:  x=torch.relu(self.conv1(x))  x=torch.relu(self.conv2(x))  x=torch.relu(self.conv3(x))  x=self.conv4(x)  returnx    

奖励模型实现

奖励模型采用了一个三层前馈神经网络结构,用于将随机状态 s 和确定性状态 h 映射到正态分布参数,进而通过采样获得奖励预测。

 classRewardModel(nn.Module):  def__init__(self, hidden_dim: int, state_dim: int):  super(RewardModel, self).__init__()  self.fc1=nn.Linear(hidden_dim+state_dim, hidden_dim)  self.fc2=nn.Linear(hidden_dim, hidden_dim)  self.fc3=nn.Linear(hidden_dim, 2)  defforward(self, h: torch.Tensor, s: torch.Tensor):  x=torch.cat([h, s], dim=-1)  x=torch.relu(self.fc1(x))  x=torch.relu(self.fc2(x))  x=self.fc3(x)  returnx

动态模型的实现

动态模型是 RSSM 架构中最复杂的组件,需要同时处理先验和后验状态转移模型:

  1. 后验转移模型:在能够访问真实观察的情况下使用(主要在训练阶段),用于在给定观察和历史状态的条件下近似随机状态的后验分布。
  2. 先验转移模型:用于近似先验状态分布,仅依赖于前一时刻状态,不依赖于观察。这在无法获取后验观察的推理阶段使用。

这两个模型均通过单层前馈网络进行参数化,输出各自正态分布的均值和对数方差,用于状态 s 的采样。该实现采用了简单的网络结构,但可以根据需要扩展为更复杂的架构。

确定性状态采用门控循环单元(GRU)实现。其输入包括:

  • 前一时刻的隐藏状态
  • 独热编码动作
  • 前一时刻随机状态 s(根据是否可以获取观察来选择使用后验或先验状态)

这些输入信息足以让模型了解动作历史和系统状态。以下是具体实现代码:

 classDynamicsModel(nn.Module):  def__init__(self, hidden_dim: int, action_dim: int, state_dim: int, embedding_dim: int, rnn_layer: int=1):  super(DynamicsModel, self).__init__()  self.hidden_dim=hidden_dim  # 递归层实现,支持多层 GRUself.rnn=nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for_inrange(rnn_layer)])  # 状态动作投影层self.project_state_action=nn.Linear(action_dim+state_dim, hidden_dim)  # 先验网络:输出正态分布参数self.prior=nn.Linear(hidden_dim, state_dim*2)  self.project_hidden_action=nn.Linear(hidden_dim+action_dim, hidden_dim)  # 后验网络:输出正态分布参数self.posterior=nn.Linear(hidden_dim, state_dim*2)  self.project_hidden_obs=nn.Linear(hidden_dim+embedding_dim, hidden_dim)  self.state_dim=state_dim  self.act_fn=nn.ReLU()  defforward(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor, actions: torch.Tensor,  obs: torch.Tensor=None, dones: torch.Tensor=None):  """  动态模型的前向传播参数:  prev_hidden: RNN的前一隐藏状态,形状 (batch_size, hidden_dim)  prev_state: 前一随机状态,形状 (batch_size, state_dim)  actions: 独热编码动作序列,形状 (sequence_length, batch_size, action_dim)  obs: 编码器输出的观察嵌入,形状 (sequence_length, batch_size, embedding_dim)  dones: 终止状态标志"""  B, T, _=actions.size()  # 用于无观察访问时的推理# 初始化存储列表hiddens_list= []  posterior_means_list= []  posterior_logvars_list= []  prior_means_list= []  prior_logvars_list= []  prior_states_list= []  posterior_states_list= []  # 存储初始状态hiddens_list.append(prev_hidden.unsqueeze(1))    prior_states_list.append(prev_state.unsqueeze(1))  posterior_states_list.append(prev_state.unsqueeze(1))  # 时序展开fortinrange(T-1):  # 提取当前时刻状态和动作action_t=actions[:, t, :]  obs_t=obs[:, t, :] ifobsisnotNoneelsetorch.zeros(B, self.embedding_dim, device=actions.device)  state_t=posterior_states_list[-1][:, 0, :] ifobsisnotNoneelseprior_states_list[-1][:, 0, :]  state_t=state_tifdonesisNoneelsestate_t* (1-dones[:, t, :])  hidden_t=hiddens_list[-1][:, 0, :]  # 状态动作组合state_action=torch.cat([state_t, action_t], dim=-1)  state_action=self.act_fn(self.project_state_action(state_action))  # RNN 状态更新foriinrange(len(self.rnn)):  hidden_t=self.rnn[i](state_action, hidden_t)  # 先验分布计算hidden_action=torch.cat([hidden_t, action_t], dim=-1)  hidden_action=self.act_fn(self.project_hidden_action(hidden_action))  prior_params=self.prior(hidden_action)  prior_mean, prior_logvar=torch.chunk(prior_params, 2, dim=-1)  # 从先验分布采样prior_dist=torch.distributions.Normal(prior_mean, torch.exp(F.softplus(prior_logvar)))  prior_state_t=prior_dist.rsample()  # 后验分布计算ifobsisNone:  posterior_mean=prior_mean  posterior_logvar=prior_logvar  else:  hidden_obs=torch.cat([hidden_t, obs_t], dim=-1)  hidden_obs=self.act_fn(self.project_hidden_obs(hidden_obs))  posterior_params=self.posterior(hidden_obs)  posterior_mean, posterior_logvar=torch.chunk(posterior_params, 2, dim=-1)  # 从后验分布采样posterior_dist=torch.distributions.Normal(posterior_mean, torch.exp(F.softplus(posterior_logvar)))  posterior_state_t=posterior_dist.rsample()  # 保存状态posterior_means_list.append(posterior_mean.unsqueeze(1))  posterior_logvars_list.append(posterior_logvar.unsqueeze(1))  prior_means_list.append(prior_mean.unsqueeze(1))  prior_logvars_list.append(prior_logvar.unsqueeze(1))  prior_states_list.append(prior_state_t.unsqueeze(1))  posterior_states_list.append(posterior_state_t.unsqueeze(1))  hiddens_list.append(hidden_t.unsqueeze(1))  # 合并时序数据hiddens=torch.cat(hiddens_list, dim=1)  prior_states=torch.cat(prior_states_list, dim=1)  posterior_states=torch.cat(posterior_states_list, dim=1)  prior_means=torch.cat(prior_means_list, dim=1)  prior_logvars=torch.cat(prior_logvars_list, dim=1)  posterior_means=torch.cat(posterior_means_list, dim=1)  posterior_logvars=torch.cat(posterior_logvars_list, dim=1)  returnhiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars

需要特别注意的是,这里的观察输入并非原始观察数据,而是经过编码器处理后的嵌入表示。这种设计能够有效降低计算复杂度并提升模型的泛化能力。

RSSM 整体架构

将前述组件整合为完整的 RSSM 模型。其核心是

generate_rollout

方法,负责调用动态模型并生成环境动态的潜在表示序列。对于没有历史潜在状态的情况(通常发生在轨迹开始时),该方法会进行必要的初始化。下面是完整的实现代码:

 classRSSM:  def__init__(self,  encoder: EncoderCNN,  decoder: DecoderCNN,  reward_model: RewardModel,  dynamics_model: nn.Module,  hidden_dim: int,  state_dim: int,  action_dim: int,  embedding_dim: int,  device: str="mps"):  """  循环状态空间模型(RSSM)实现参数:encoder: 确定性状态编码器decoder: 观察重构解码器reward_model: 奖励预测模型dynamics_model: 状态动态模型hidden_dim: RNN 隐藏层维度state_dim: 随机状态维度action_dim: 动作空间维度embedding_dim: 观察嵌入维度device: 计算设备"""  super(RSSM, self).__init__()  # 模型组件初始化self.dynamics=dynamics_model  self.encoder=encoder  self.decoder=decoder  self.reward_model=reward_model  # 维度参数存储self.hidden_dim=hidden_dim  self.state_dim=state_dim  self.action_dim=action_dim  self.embedding_dim=embedding_dim  # 模型迁移至指定设备self.dynamics.to(device)  self.encoder.to(device)  self.decoder.to(device)  self.reward_model.to(device)  defgenerate_rollout(self, actions: torch.Tensor, hiddens: torch.Tensor=None, states: torch.Tensor=None,  obs: torch.Tensor=None, dones: torch.Tensor=None):  """生成状态序列展开参数:actions: 动作序列hiddens: 初始隐藏状态(可选)states: 初始随机状态(可选)obs: 观察序列(可选)dones: 终止标志序列返回:完整的状态展开序列"""# 状态初始化ifhiddensisNone:  hiddens=torch.zeros(actions.size(0), self.hidden_dim).to(actions.device)  ifstatesisNone:  states=torch.zeros(actions.size(0), self.state_dim).to(actions.device)  # 执行动态模型展开dynamics_result=self.dynamics(hiddens, states, actions, obs, dones)  hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars=dynamics_result  returnhiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars  deftrain(self):  """启用训练模式"""self.dynamics.train()  self.encoder.train()  self.decoder.train()  self.reward_model.train()  defeval(self):  """启用评估模式"""self.dynamics.eval()  self.encoder.eval()  self.decoder.eval()  self.reward_model.eval()  defencode(self, obs: torch.Tensor):  """观察编码"""returnself.encoder(obs)  defdecode(self, state: torch.Tensor):  """状态解码为观察"""returnself.decoder(state)  defpredict_reward(self, h: torch.Tensor, s: torch.Tensor):  """奖励预测"""returnself.reward_model(h, s)  defparameters(self):  """返回所有可训练参数"""returnlist(self.dynamics.parameters()) +list(self.encoder.parameters()) + \list(self.decoder.parameters()) +list(self.reward_model.parameters())  defsave(self, path: str):  """模型状态保存"""torch.save({  "dynamics": self.dynamics.state_dict(),  "encoder": self.encoder.state_dict(),  "decoder": self.decoder.state_dict(),  "reward_model": self.reward_model.state_dict()  }, path)  defload(self, path: str):  """模型状态加载"""checkpoint=torch.load(path)  self.dynamics.load_state_dict(checkpoint["dynamics"])  self.encoder.load_state_dict(checkpoint["encoder"])  self.decoder.load_state_dict(checkpoint["decoder"])  self.reward_model.load_state_dict(checkpoint["reward_model"])

这个实现提供了一个完整的 RSSM 框架,包含了模型的训练、评估、状态保存和加载等基本功能。该框架可以作为基础结构,根据具体应用场景进行扩展和优化。

训练系统设计

RSSM 的训练系统主要包含两个核心组件:经验回放缓冲区(Experience Replay Buffer)和智能体(Agent)。其中,缓冲区负责存储历史经验数据用于训练,而智能体则作为环境与 RSSM 之间的接口,实现数据收集策略。

经验回放缓冲区实现

缓冲区采用循环队列结构,用于存储和管理观察、动作、奖励和终止状态等数据。通过

sample

方法可以随机采样训练序列。

 classBuffer:  def__init__(self, buffer_size: int, obs_shape: tuple, action_shape: tuple, device: torch.device):  """经验回放缓冲区初始化参数:buffer_size: 缓冲区容量obs_shape: 观察数据维度action_shape: 动作数据维度device: 计算设备"""self.buffer_size=buffer_size  self.obs_buffer=np.zeros((buffer_size, *obs_shape), dtype=np.float32)  self.action_buffer=np.zeros((buffer_size, *action_shape), dtype=np.int32)  self.reward_buffer=np.zeros((buffer_size, 1), dtype=np.float32)  self.done_buffer=np.zeros((buffer_size, 1), dtype=np.bool_)  self.device=device  self.idx=0  defadd(self, obs: torch.Tensor, action: int, reward: float, done: bool):  """添加单步经验数据"""self.obs_buffer[self.idx] =obs  self.action_buffer[self.idx] =action  self.reward_buffer[self.idx] =reward  self.done_buffer[self.idx] =done  self.idx= (self.idx+1) %self.buffer_size  defsample(self, batch_size: int, sequence_length: int):  """随机采样经验序列参数:batch_size: 批量大小sequence_length: 序列长度返回:经验数据元组 (observations, actions, rewards, dones)"""# 随机选择序列起始位置starting_idxs=np.random.randint(0, (self.idx%self.buffer_size) -sequence_length, (batch_size,))  # 构建完整序列索引index_tensor=np.stack([np.arange(start, start+sequence_length) forstartinstarting_idxs])  # 提取数据序列obs_sequence=self.obs_buffer[index_tensor]  action_sequence=self.action_buffer[index_tensor]  reward_sequence=self.reward_buffer[index_tensor]  done_sequence=self.done_buffer[index_tensor]  returnobs_sequence, action_sequence, reward_sequence, done_sequence  defsave(self, path: str):  """保存缓冲区数据"""np.savez(path, obs_buffer=self.obs_buffer, action_buffer=self.action_buffer,  reward_buffer=self.reward_buffer, done_buffer=self.done_buffer, idx=self.idx)  defload(self, path: str):  """加载缓冲区数据"""data=np.load(path)  self.obs_buffer=data["obs_buffer"]  self.action_buffer=data["action_buffer"]  self.reward_buffer=data["reward_buffer"]  self.done_buffer=data["done_buffer"]  self.idx=data["idx"]

智能体设计

智能体实现了数据收集和规划功能。当前实现采用了简单的随机策略进行数据收集,但该框架支持扩展更复杂的策略。

 classPolicy(ABC):  """策略基类"""@abstractmethod  def__call__(self, obs):  pass  classRandomPolicy(Policy):  """随机采样策略"""def__init__(self, env: Env):  self.env=env  def__call__(self, obs):  returnself.env.action_space.sample()  classAgent:  def__init__(self, env: Env, rssm: RSSM, buffer_size: int=100000, collection_policy: str="random", device="mps"):  """智能体初始化参数:env: 环境实例rssm: RSSM模型实例buffer_size: 经验缓冲区大小collection_policy: 数据收集策略类型device: 计算设备"""self.env=env  # 策略选择matchcollection_policy:  case"random":  self.rollout_policy=RandomPolicy(env)  case_:  raiseValueError("Invalid rollout policy")  self.buffer=Buffer(buffer_size, env.observation_space.shape, env.action_space.shape, device=device)  self.rssm=rssm  defdata_collection_action(self, obs):  """执行数据收集动作"""returnself.rollout_policy(obs)  defcollect_data(self, num_steps: int):  """收集训练数据参数:num_steps: 收集步数"""obs=self.env.reset()  done=False  iterator=tqdm(range(num_steps), desc="Data Collection")  for_initerator:  action=self.data_collection_action(obs)  next_obs, reward, done, _, _=self.env.step(action)  self.buffer.add(next_obs, action, reward, done)  obs=next_obs  ifdone:  obs=self.env.reset()  defimagine_rollout(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor, actions: torch.Tensor):  """执行想象展开参数:prev_hidden: 前一隐藏状态prev_state: 前一随机状态actions: 动作序列返回:完整的模型输出,包括隐藏状态、先验状态、后验状态等"""hiddens, prior_states, posterior_states, prior_means, prior_logvars, \posterior_means, posterior_logvars=self.rssm.generate_rollout(actions, prev_hidden, prev_state)  # 在想象阶段使用先验状态预测奖励rewards=self.rssm.predict_reward(hiddens, prior_states)  returnhiddens, prior_states, posterior_states, prior_means, \prior_logvars, posterior_means, posterior_logvars, rewards  defplan(self, num_steps: int, prev_hidden: torch.Tensor, prev_state: torch.Tensor, actions: torch.Tensor):  """执行规划参数:num_steps: 规划步数prev_hidden: 初始隐藏状态prev_state: 初始随机状态actions: 动作序列返回:规划得到的隐藏状态和先验状态序列"""hidden_states= []  prior_states= []  hiddens=prev_hidden  states=prev_state  for_inrange(num_steps):  hiddens, states, _, _, _, _, _, _=self.imagine_rollout(hiddens, states, actions)  hidden_states.append(hiddens)  prior_states.append(states)  hidden_states=torch.stack(hidden_states)  prior_states=torch.stack(prior_states)  returnhidden_states, prior_states

这部分实现提供了完整的数据管理和智能体交互框架。通过经验回放缓冲区,可以高效地存储和重用历史数据;通过智能体的抽象策略接口,可以方便地扩展不同的数据收集策略。同时智能体还实现了基于模型的想象展开和规划功能,为后续的决策制定提供了基础。

训练器实现与实验

训练器设计

训练器是 RSSM 实现中的最后一个关键组件,负责协调模型训练过程。训练器接收 RSSM 模型、智能体、优化器等组件,并实现具体的训练逻辑。

 logging.basicConfig(  level=logging.INFO,  format="%(asctime)s - %(levelname)s - %(message)s",  handlers=[  logging.StreamHandler(),  # 控制台输出logging.FileHandler("training.log", mode="w")  # 文件输出]  )  logger=logging.getLogger(__name__)  classTrainer:  def__init__(self, rssm: RSSM, agent: Agent, optimizer: torch.optim.Optimizer, device: torch.device):  """训练器初始化参数:rssm: RSSM 模型实例agent: 智能体实例optimizer: 优化器实例device: 计算设备"""self.rssm=rssm  self.optimizer=optimizer  self.device=device  self.agent=agent  self.writer=SummaryWriter()  # tensorboard 日志记录器deftrain_batch(self, batch_size: int, seq_len: int, iteration: int, save_images: bool=False):  """单批次训练参数:batch_size: 批量大小seq_len: 序列长度iteration: 当前迭代次数save_images: 是否保存重建图像"""# 采样训练数据obs, actions, rewards, dones=self.agent.buffer.sample(batch_size, seq_len)  # 数据预处理actions=torch.tensor(actions).long().to(self.device)  actions=F.one_hot(actions, self.rssm.action_dim).float()  obs=torch.tensor(obs, requires_grad=True).float().to(self.device)  rewards=torch.tensor(rewards, requires_grad=True).float().to(self.device)  dones=torch.tensor(dones).float().to(self.device)  # 观察编码encoded_obs=self.rssm.encoder(obs.reshape(-1, *obs.shape[2:]).permute(0, 3, 1, 2))  encoded_obs=encoded_obs.reshape(batch_size, seq_len, -1)  # 执行 RSSM 展开rollout=self.rssm.generate_rollout(actions, obs=encoded_obs, dones=dones)  hiddens, prior_states, posterior_states, prior_means, prior_logvars, \posterior_means, posterior_logvars=rollout  # 重构观察hiddens_reshaped=hiddens.reshape(batch_size*seq_len, -1)  posterior_states_reshaped=posterior_states.reshape(batch_size*seq_len, -1)  decoded_obs=self.rssm.decoder(hiddens_reshaped, posterior_states_reshaped)  decoded_obs=decoded_obs.reshape(batch_size, seq_len, *obs.shape[-3:])  # 奖励预测reward_params=self.rssm.reward_model(hiddens, posterior_states)  mean, logvar=torch.chunk(reward_params, 2, dim=-1)  logvar=F.softplus(logvar)  reward_dist=Normal(mean, torch.exp(logvar))  predicted_rewards=reward_dist.rsample()  # 可视化ifsave_images:  batch_idx=np.random.randint(0, batch_size)  seq_idx=np.random.randint(0, seq_len-3)  fig=self._visualize(obs, decoded_obs, rewards, predicted_rewards, batch_idx, seq_idx, iteration, grayscale=True)  ifnotos.path.exists("reconstructions"):  os.makedirs("reconstructions")  fig.savefig(f"reconstructions/iteration_{iteration}.png")  self.writer.add_figure("Reconstructions", fig, iteration)  plt.close(fig)  # 计算损失reconstruction_loss=self._reconstruction_loss(decoded_obs, obs)  kl_loss=self._kl_loss(prior_means, F.softplus(prior_logvars), posterior_means, F.softplus(posterior_logvars))  reward_loss=self._reward_loss(rewards, predicted_rewards)  loss=reconstruction_loss+kl_loss+reward_loss  # 反向传播和优化self.optimizer.zero_grad()  loss.backward()  nn.utils.clip_grad_norm_(self.rssm.parameters(), 1, norm_type=2)  self.optimizer.step()  returnloss.item(), reconstruction_loss.item(), kl_loss.item(), reward_loss.item()  deftrain(self, iterations: int, batch_size: int, seq_len: int):  """执行完整训练过程参数:iterations: 迭代总次数batch_size: 批量大小seq_len: 序列长度"""self.rssm.train()  iterator=tqdm(range(iterations), desc="Training", total=iterations)  losses= []  infos= []  last_loss=float("inf")  foriiniterator:  # 执行单批次训练loss, reconstruction_loss, kl_loss, reward_loss=self.train_batch(batch_size, seq_len, i, save_images=i%100==0)  # 记录训练指标self.writer.add_scalar("Loss", loss, i)  self.writer.add_scalar("Reconstruction Loss", reconstruction_loss, i)  self.writer.add_scalar("KL Loss", kl_loss, i)  self.writer.add_scalar("Reward Loss", reward_loss, i)  # 保存最佳模型ifloss<last_loss:  self.rssm.save("rssm.pth")  last_loss=loss  # 记录详细信息info= {  "Loss": loss,  "Reconstruction Loss": reconstruction_loss,  "KL Loss": kl_loss,  "Reward Loss": reward_loss  }  losses.append(loss)  infos.append(info)  # 定期输出训练状态ifi%10==0:  logger.info("\n----------------------------")  logger.info(f"Iteration: {i}")  logger.info(f"Loss: {loss:.4f}")  logger.info(f"Running average last 20 losses: {sum(losses[-20:]) /20: .4f}")  logger.info(f"Reconstruction Loss: {reconstruction_loss:.4f}")  logger.info(f"KL Loss: {kl_loss:.4f}")  logger.info(f"Reward Loss: {reward_loss:.4f}")### 实验示例以下是一个在CarRacing环境中训练RSSM的完整示例:```python# 环境初始化env=make_env("CarRacing-v2", render_mode="rgb_array", continuous=False, grayscale=True)  # 模型参数设置hidden_size=1024  embedding_dim=1024  state_dim=512  # 模型组件实例化encoder=EncoderCNN(in_channels=1, embedding_dim=embedding_dim)  decoder=DecoderCNN(hidden_size=hidden_size, state_size=state_dim, embedding_size=embedding_dim, output_shape=(1,128,128))  reward_model=RewardModel(hidden_dim=hidden_size, state_dim=state_dim)  dynamics_model=DynamicsModel(hidden_dim=hidden_size, state_dim=state_dim, action_dim=5, embedding_dim=embedding_dim)  # RSSM 模型构建rssm=RSSM(dynamics_model=dynamics_model,  encoder=encoder,  decoder=decoder,  reward_model=reward_model,  hidden_dim=hidden_size,  state_dim=state_dim,  action_dim=5,  embedding_dim=embedding_dim)  # 训练设置optimizer=torch.optim.Adam(rssm.parameters(), lr=1e-3)  agent=Agent(env, rssm)  trainer=Trainer(rssm, agent, optimizer=optimizer, device="cuda")  # 数据收集和训练trainer.collect_data(20000)  # 收集 20000 步经验数据trainer.save_buffer("buffer.npz")  # 保存经验缓冲区trainer.train(10000, 32, 20)  # 执行 10000 次迭代训练

总结

本文详细介绍了基于 PyTorch 实现 RSSM 的完整过程。RSSM 的架构相比传统的 VAE 或 RNN 更为复杂,这主要源于其混合了随机和确定性状态的特性。通过手动实现这一架构,我们可以深入理解其背后的理论基础及其强大之处。RSSM 能够递归地生成未来潜在状态轨迹,这为智能体的行为规划提供了基础。

实现的优点在于其计算负载适中,可以在单个消费级 GPU 上进行训练,在有充足时间的情况下甚至可以在 CPU 上运行。这一工作基于论文《Learning Latent Dynamics for Planning from Pixels》,该论文为 RSSM 类动态模型奠定了基础。后续的研究工作如《Dream to Control: Learning Behaviors by Latent Imagination》进一步发展了这一架构。这些改进的架构将在未来的研究中深入探讨,因为它们对理解 MBRL 方法提供了重要的见解。

https://avoid.overfit.cn/post/8d8412f5ef6544e4ba097547a38830ac

作者:Lukas Bierling


http://www.ppmy.cn/devtools/150301.html

相关文章

Matlab APP Designer

我想给聚类的代码加一个图形化界面&#xff0c;需要输入一些数据和一些参数并输出聚类后的图像和一些评价指标的值。 gpt说 可以用 app designer 界面元素设计 在 设计视图 中直接拖动即可 如图1&#xff0c;我拖进去一个 按钮 &#xff0c;图2 红色部分 出现一行 Button 图…

贪心算法汇总

1.贪心算法 贪心的本质是选择每一阶段的局部最优&#xff0c;从而达到全局最优。 如何能看出局部最优是否能推出整体最优 靠自己手动模拟&#xff0c;如果模拟可行&#xff0c;就可以试一试贪心策略&#xff0c;如果不可行&#xff0c;可能需要动态规划。 如何验证可不可以…

CentOS下安装Docker

Docker 必须要在Linux环境下才能运行&#xff0c;windows下运行也是安装虚拟机后才能下载安装运行&#xff0c;菜鸟教程 下载安装 linux 依次执行下边步骤 更新 yum yum update 卸载旧的Docker yum remove docker docker-client docker-client-latest docker-common doc…

(STM32笔记)十二、DMA的基础知识与用法 第三部分

我用的是正点的STM32F103来进行学习&#xff0c;板子和教程是野火的指南者。 之后的这个系列笔记开头未标明的话&#xff0c;用的也是这个板子和教程。 DMA的基础知识与用法 三、DMA程序验证1、DMA 存储器到存储器模式实验&#xff08;1&#xff09;DMA结构体解释&#xff08;2…

QT c++ 样式 设置 按钮(QPushButton)的渐变色美化

上一篇文章中描述了标签的渐变色美化,本文描述按钮的渐变色美化。 1.头文件 #ifndef WIDGET_H #define WIDGET_H #include <QWidget> //#include "CustomButton.h"#include <QVBoxLayout> #include <QLinearGradient> #include <QPushButton&…

Maven在不同操作系统上如何安装?

大家好&#xff0c;我是袁庭新。Maven是一个重要的工具&#xff0c;还有很多初学者竟然不知道如何安装Maven&#xff1f;这篇文章将系统介绍如何在Windows、macOS、Linux操作系统上安装Maven。 Maven是一个基于Java的项目管理工具。因此&#xff0c;最基本的要求是在计算机上安…

uniapp使用sm4加密

安装&#xff1a;npm install sm-crypto --save 1、在utils下新建crypto.js文件 // sm4 加密 export function encryption(params) {const SM4 require("sm-crypto").sm4const key 0123456789abcdeffedcba9876543212; // 提供的密钥const iv fedcba9876543210012…

Android studio gradle与gradle插件

最终换gradle版本&#xff0c;糊成一坨。 记录一下 Android studio里有两个容易搞混&#xff0c;记载一下。 build.gradle文件中的为插件版本&#xff1a; classpath "com.android.tools.build:gradle:3.5.0" gradle.properties里的才是gradle版本。 distributio…