PyTorch 深度学习实战(19):离线强化学习与 Conservative Q-Learning (CQL) 算法

devtools/2025/3/25 23:10:22/

在上一篇文章中,我们探讨了分布式强化学习与 IMPALA 算法,展示了如何通过并行化训练提升强化学习的效率。本文将聚焦 离线强化学习(Offline RL) 这一新兴方向,并实现 Conservative Q-Learning (CQL) 算法,利用 Minari 提供的静态数据集训练安全的强化学习策略。


一、离线强化学习与 CQL 原理

1. 离线强化学习的特点
  • 无需环境交互:直接从预收集的静态数据集学习

  • 数据效率高:复用历史经验(如人类演示、日志数据)

  • 安全风险低:避免在线探索中的危险行为

2. CQL 核心思想

CQL 通过保守策略评估防止价值函数高估,其目标函数为:

3. 算法优势
  • 防止分布偏移导致的策略退化

  • 支持混合质量数据集(专家数据 + 随机数据)

  • 适用于真实世界场景(如医疗、金融)


二、CQL 实现步骤(基于 Minari 数据集)

我们将使用 Minari 库中的 D4RL/door/human-v2 数据集训练策略:

  1. 安装 Minari 并加载数据集

  2. 定义保守 Q 网络

  3. 实现保守正则化损失

  4. 策略优化与评估


三、代码实现

以下是 CQL 算法的完整实现代码:

import torch
import minari
import numpy as np
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from collections import deque
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
​
# 1. 增强型配置类(带维度校验)
class SafeConfig:# 训练参数batch_size = 1024lr = 3e-5tau = 0.007gamma = 0.99total_epochs = 500# 网络架构hidden_dim = 768num_layers = 3dropout_rate = 0.1activation_fn = 'Mish'  # 支持Mish/SiLU/ReLU# 正则化参数conservative_init = 2.5conservative_decay = 0.995min_conservative = 0.3reward_scale = 4.0# 探索参数noise_scale = 0.2noise_clip = 0.5candidate_samples = 400imitation_ratio = 0.15
​
# 2. 安全数据加载系统
class SafeDataset(Dataset):def __init__(self, dataset_name):# 加载原始数据dataset = minari.load_dataset(dataset_name, download=True)# 获取维度信息first_ep = dataset[0]self.state_dim = first_ep.observations[0].shape[0]self.action_dim = first_ep.actions[0].shape[0]# 数据存储self.obs, self.acts, self.rews, self.dones, self.next_obs = [], [], [], [], []for ep in dataset:self._store_episode(ep.observations[:-1],ep.actions,ep.rewards,np.logical_or(ep.terminations, ep.truncations),ep.observations[1:])# 标准化self._normalize()self.priorities = np.ones(len(self.obs)) * 1e-5def _store_episode(self, obs, acts, rews, dones, next_obs):self.obs.extend(obs)self.acts.extend(acts)self.rews.extend(rews)self.dones.extend(dones)self.next_obs.extend(next_obs)def _normalize(self):# 状态标准化self.obs_mean = np.mean(self.obs, axis=0)self.obs_std = np.std(self.obs, axis=0) + 1e-8self.obs = (self.obs - self.obs_mean) / self.obs_stdself.next_obs = (self.next_obs - self.obs_mean) / self.obs_std# 动作标准化self.act_mean = np.mean(self.acts, axis=0)self.act_std = np.std(self.acts, axis=0) + 1e-8self.acts = (self.acts - self.act_mean) / self.act_stddef update_priorities(self, indices, priorities):self.priorities[indices] = np.abs(priorities.flatten()) + 1e-5def __len__(self):return len(self.obs)def __getitem__(self, idx):return (idx,torch.FloatTensor(self.obs[idx]),torch.FloatTensor(self.acts[idx]),torch.FloatTensor(self.next_obs[idx]),torch.FloatTensor([self.rews[idx]]),torch.FloatTensor([bool(self.dones[idx])]))
​
# 3. 维度安全网络架构
class SafeQNetwork(torch.nn.Module):def __init__(self, state_dim, action_dim):super().__init__()self.state_dim = state_dimself.action_dim = action_dimself.input_dim = state_dim + action_dim  # 关键动态计算# 主网络self.feature_net = self._build_network()self.q1 = torch.nn.Linear(SafeConfig.hidden_dim, 1)self.q2 = torch.nn.Linear(SafeConfig.hidden_dim, 1)# 目标网络self.target_net = self._build_network()self.target_q1 = torch.nn.Linear(SafeConfig.hidden_dim, 1)self.target_q2 = torch.nn.Linear(SafeConfig.hidden_dim, 1)# 初始化self._init_weights()self._update_target(1.0)def _build_network(self):layers = []input_dim = self.input_dim  # 使用动态计算值for _ in range(SafeConfig.num_layers):layers.extend([torch.nn.Linear(input_dim, SafeConfig.hidden_dim),torch.nn.LayerNorm(SafeConfig.hidden_dim),self._activation(),torch.nn.Dropout(SafeConfig.dropout_rate),])input_dim = SafeConfig.hidden_dimreturn torch.nn.Sequential(*layers)def _activation(self):return {'Mish': torch.nn.Mish(),'SiLU': torch.nn.SiLU(),'ReLU': torch.nn.ReLU()}[SafeConfig.activation_fn]def _init_weights(self):for m in self.modules():if isinstance(m, torch.nn.Linear):torch.nn.init.orthogonal_(m.weight)torch.nn.init.normal_(m.bias, 0, 0.1)def forward(self, state, action):# 维度校验assert state.shape[-1] == self.state_dim, f"State dim error: {state.shape[-1]} vs {self.state_dim}"assert action.shape[-1] == self.action_dim, f"Action dim error: {action.shape[-1]} vs {self.action_dim}"x = torch.cat([state, action], dim=1)features = self.feature_net(x)return self.q1(features), self.q2(features)def target_forward(self, state, action):x = torch.cat([state, action], dim=1)features = self.target_net(x)return self.target_q1(features), self.target_q2(features)def _update_target(self, tau):with torch.no_grad():for t_param, param in zip(self.target_net.parameters(), self.feature_net.parameters()):t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)for t_param, param in zip(self.target_q1.parameters(), self.q1.parameters()):t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)for t_param, param in zip(self.target_q2.parameters(), self.q2.parameters()):t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)
​
# 4. 安全训练系统
class SafeTrainer:def __init__(self, dataset_name):# 数据系统self.dataset = SafeDataset(dataset_name)self.state_dim = self.dataset.state_dimself.action_dim = self.dataset.action_dim# 网络系统self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.q_net = SafeQNetwork(self.state_dim, self.action_dim).to(self.device)# 优化系统self.optimizer = torch.optim.AdamW(self.q_net.parameters(),lr=SafeConfig.lr,weight_decay=1e-3)self.scheduler = CosineAnnealingWarmRestarts(self.optimizer,T_0=100,eta_min=1e-6)# 数据加载self.dataloader = DataLoader(self.dataset,batch_size=SafeConfig.batch_size,sampler=WeightedRandomSampler(self.dataset.priorities,num_samples=len(self.dataset),replacement=True),collate_fn=lambda b: {'indices': torch.LongTensor([x[0] for x in b]),'states': torch.stack([x[1] for x in b]),'actions': torch.stack([x[2] for x in b]),'next_states': torch.stack([x[3] for x in b]),'rewards': torch.stack([x[4] for x in b]),'dones': torch.stack([x[5] for x in b])},num_workers=4)# 训练状态self.conservative_weight = SafeConfig.conservative_initself.loss_history = deque(maxlen=100)def train_epoch(self, epoch):self.q_net.train()total_loss = 0.0for batch in self.dataloader:# 数据准备states = batch['states'].to(self.device)actions = batch['actions'].to(self.device)next_states = batch['next_states'].to(self.device)rewards = batch['rewards'].to(self.device) * SafeConfig.reward_scaledones = batch['dones'].to(self.device)# 目标Q值计算with torch.no_grad():# 带噪声的动作生成noise = torch.randn_like(actions) * SafeConfig.noise_scalenoise = torch.clamp(noise, -SafeConfig.noise_clip, SafeConfig.noise_clip)noisy_actions = actions + noise# 双Q学习target_q1, target_q2 = self.q_net.target_forward(next_states, noisy_actions)target_q = torch.min(target_q1, target_q2).squeeze(-1)y = rewards.squeeze(-1) + (1 - dones.squeeze(-1)) * SafeConfig.gamma * target_q# 当前Q值预测current_q1, current_q2 = self.q_net(states, actions)current_q1 = current_q1.squeeze(-1).clamp(-10.0, 50.0)current_q2 = current_q2.squeeze(-1).clamp(-10.0, 50.0)# 损失计算bellman_loss = 0.5 * (torch.nn.functional.huber_loss(current_q1, y, delta=1.0) +torch.nn.functional.huber_loss(current_q2, y, delta=1.0))# 保守正则项rand_acts = torch.randn_like(actions) * SafeConfig.noise_scaleq1_rand, q2_rand = self.q_net(states, rand_acts)conservative_loss = (q1_rand + q2_rand).mean() - (current_q1 + current_q2).mean()# 总损失loss = bellman_loss + self.conservative_weight * conservative_loss# 反向传播self.optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 2.0)self.optimizer.step()# 更新目标网络self.q_net._update_target(SafeConfig.tau)# 更新优先级td_errors = (current_q1 - y).detach().cpu().numpy()self.dataset.update_priorities(batch['indices'].numpy(), td_errors)total_loss += loss.item()# 调整保守权重self.conservative_weight = max(self.conservative_weight * SafeConfig.conservative_decay,SafeConfig.min_conservative)# 学习率调度self.scheduler.step()return total_loss / len(self.dataloader)def get_action(self, state):self.q_net.eval()state_norm = (state - self.dataset.obs_mean) / self.dataset.obs_stdstate_tensor = torch.FloatTensor(state_norm).unsqueeze(0).to(self.device)# 候选动作生成num_imitation = int(SafeConfig.candidate_samples * SafeConfig.imitation_ratio)imitation_idx = np.random.choice(len(self.dataset), num_imitation)imitation_acts = self.dataset.acts[imitation_idx]noise_acts = np.random.randn(SafeConfig.candidate_samples - num_imitation, self.action_dim)candidates = np.concatenate([imitation_acts, noise_acts])candidates = (candidates * self.dataset.act_std) + self.dataset.act_mean# 选择最优动作with torch.no_grad():state_batch = state_tensor.repeat(SafeConfig.candidate_samples, 1)candidate_tensor = torch.FloatTensor(candidates).to(self.device)candidate_norm = (candidate_tensor - self.dataset.act_mean) / self.dataset.act_stdq_values, _ = self.q_net(state_batch, candidate_norm)best_idx = torch.argmax(q_values)return candidates[best_idx.cpu().item()]
​
# 5. 训练执行
if __name__ == "__main__":trainer = SafeTrainer("D4RL/door/human-v2")print(f"初始化维度检查: state={trainer.state_dim}, action={trainer.action_dim}")try:for epoch in range(SafeConfig.total_epochs):loss = trainer.train_epoch(epoch)if (epoch + 1) % 20 == 0:print(f"Epoch {epoch+1:04d} | Loss: {loss:.2f} | "f"Conserv: {trainer.conservative_weight:.2f} | "f"LR: {trainer.scheduler.get_last_lr()[0]:.1e}")except KeyboardInterrupt:print("\n训练中断,保存检查点...")torch.save(trainer.q_net.state_dict(), "interrupted.pth")print("训练完成...")

四、关键代码解析

  1. 数据集加载

    • 使用 minari.load_dataset 加载离线数据集

    • 数据集包含状态、动作、奖励、终止标志等信息

  2. 保守正则化实现

    • 通过随机动作采样计算正则项

    • 超参数 $\alpha$ 控制保守程度

  3. 策略提取技巧

    • 采用基于 Q 值的启发式策略

    • 通过多候选动作采样提升稳定性


五、训练结果

运行代码将观察到:

初始化维度检查: state=39, action=28
Epoch 0020 | Loss: -46.52 | Conserv: 2.26 | LR: 2.7e-05
Epoch 0040 | Loss: -73.80 | Conserv: 2.05 | LR: 2.0e-05
Epoch 0060 | Loss: -73.50 | Conserv: 1.85 | LR: 1.1e-05
Epoch 0080 | Loss: -64.76 | Conserv: 1.67 | LR: 3.8e-06
Epoch 0100 | Loss: -54.37 | Conserv: 1.51 | LR: 3.0e-05
Epoch 0120 | Loss: -59.95 | Conserv: 1.37 | LR: 2.7e-05
Epoch 0140 | Loss: -60.11 | Conserv: 1.24 | LR: 2.0e-05
Epoch 0160 | Loss: -54.49 | Conserv: 1.12 | LR: 1.1e-05
Epoch 0180 | Loss: -46.11 | Conserv: 1.01 | LR: 3.8e-06
Epoch 0200 | Loss: -37.10 | Conserv: 0.92 | LR: 3.0e-05
Epoch 0220 | Loss: -37.56 | Conserv: 0.83 | LR: 2.7e-05
Epoch 0240 | Loss: -36.40 | Conserv: 0.75 | LR: 2.0e-05
Epoch 0260 | Loss: -31.79 | Conserv: 0.68 | LR: 1.1e-05
Epoch 0280 | Loss: -24.44 | Conserv: 0.61 | LR: 3.8e-06
Epoch 0300 | Loss: -17.06 | Conserv: 0.56 | LR: 3.0e-05
Epoch 0320 | Loss: -17.40 | Conserv: 0.50 | LR: 2.7e-05
Epoch 0340 | Loss: -16.91 | Conserv: 0.45 | LR: 2.0e-05
Epoch 0360 | Loss: -12.76 | Conserv: 0.41 | LR: 1.1e-05
Epoch 0380 | Loss: -7.27 | Conserv: 0.37 | LR: 3.8e-06
Epoch 0400 | Loss: -0.27 | Conserv: 0.34 | LR: 3.0e-05
Epoch 0420 | Loss: -1.47 | Conserv: 0.30 | LR: 2.7e-05
Epoch 0440 | Loss: -2.50 | Conserv: 0.30 | LR: 2.0e-05
Epoch 0460 | Loss: -2.87 | Conserv: 0.30 | LR: 1.1e-05
Epoch 0480 | Loss: -2.64 | Conserv: 0.30 | LR: 3.8e-06
Epoch 0500 | Loss: -2.30 | Conserv: 0.30 | LR: 3.0e-05
训练完成...


六、总结与扩展

本文基于 Minari 实现了 CQL 算法的核心逻辑,展示了离线强化学习在安全关键场景的应用价值。读者可尝试以下扩展:

  1. 添加策略网络实现 Actor-Critic 架构

  2. antmaze 等迷宫类数据集测试导航能力

  3. 实现更精确的 OOD(分布外)动作检测

在下一篇文章中,我们将探索 基于模型的强化学习(Model-Based RL),并实现 PETS 算法


注意事项

  1. 需先安装 minari 库:

    pip install "minari[all]"
  2. 数据集路径可通过 minari.list_datasets() 查看

  3. 调整 alpha 参数可平衡保守性与探索性

希望本文能帮助您理解离线强化学习的核心范式!欢迎在评论区分享您的实践心得。


http://www.ppmy.cn/devtools/171182.html

相关文章

信号处理等相关知识点

TDNN(时延神经网络)--CNN神经网络的基础 普通神经网络: 只包含一帧的特征向量 MFCC :用于语音特征提取的算法,提取出音色(很能区分不同人的说话声音)。 TDNN 滤波器:重要特征提取。 迁移学习 小波散射变换 (WST) 小波变换--傅里叶时间无限-》时间局域 点乘:求向…

【Linux文件IO】Linux中标准IO的API的描述和基本用法

Linux中标准IO的API的描述和基本用法 一、标准IO相关API1、文件的打开和关闭示例代码: 2、文件的读写示例代码:用标准IO(fread、fwrite)实现文件拷贝(任何文件均可拷贝) 3、文件偏移设置示例代码: 4、fgets fputs fget…

【Azure 架构师学习笔记】- Azure Networking(1) -- Service Endpoint 和 Private Endpoint

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Networking】系列。 前言 最近公司的安全部门在审计云环境安全性时经常提到service endpoint(SE)和priavate endpoint(PE)的术语,为此做了一些研究储备。 云…

计算机二级web易错点(7)-选择题

在 JavaScript 中,substr() 方法用于从字符串中提取子字符串。它接受两个参数,第一个参数表示开始提取的位置(索引从 0 开始),第二个参数表示要提取的字符数量。 在代码 var str"abcdefgh"; alert(str.subs…

大模型技术分类与技术演进研究

大模型技术分类与技术演进研究 人工智能领域的快速发展催生了多种大模型技术体系,其技术分类可从模型架构、训练范式、应用场景三个维度进行系统性划分。不同技术路径在算法原理、实现方式及产业应用中展现出显著差异,共同推动着AI技术边界的持续拓展。…

深度学习框架PyTorch——从入门到精通(6.2)自动微分机制

本节自动微分机制是上一节自动微分的扩展内容 自动微分是如何记录运算历史的保存张量 非可微函数的梯度在本地设置禁用梯度计算设置requires_grad梯度模式(Grad Modes)默认模式(梯度模式)无梯度模式推理模式评估模式(n…

加速还是安全?CDN与群联云防护的本质差异与适用场景

一、核心功能定位对比 维度传统CDN群联云防护核心目标内容加速(降低延迟、提升访问速度)安全防护(抵御DDoS/CC攻击、隐藏源站)技术重心缓存优化、边缘节点分发流量清洗、AI行为分析、加密隧道主要能力静态资源缓存、负载均衡攻击…

AWS 日本东京 EC2 VPS 性能、线路评测

原文链接更好的阅读体验:AWS 日本东京 EC2 VPS 性能、线路评测 本期详细记录 AWS EC2 日本区域 VPS 的性能和主要的大陆路由速度情况,方便自己以后查阅。这台 VPS 是 AWS 新用户十二个月免费机器,类型配置不高,主要是看网络情况&…