强化学习(RL)中数据采样存储与加载--TorchRL
在训练RL的过程中,使用 TorchRL 库初始化一个高性能、支持优先级的经验回放缓冲区。
使用 TorchRL 库初始化一个高性能、支持优先级的经验回放缓冲区:
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
这是关于在强化学习(RL)中使用 TensorDictPrioritizedReplayBuffer 进行经验回放(Experience Replay)和优先经验回放(Prioritized Experience Replay, PER)的配置说明。
📚 解释:为什么这样做?
1. Unified PER Support(统一支持)
现代的 RL 库(如 TorchRL)通常倾向于提供一个统一的、功能全面的数据结构来处理经验回放。
-
TensorDictPrioritizedReplayBuffer就是这个统一的结构。它天生支持复杂的 PER 机制(优先级计算和更新、IS 权重计算)。 -
好处:通过一个数据结构可以支持多种采样模式(均匀采样、PER),从而简化了代码库的结构和维护。
2. Uniform Sampling(均匀采样)
在很多基础 RL 算法或进行基准测试时,并不需要 PER 带来的复杂性。
-
均匀采样是标准的经验回放方式,它以相同的概率选择回放缓冲区中的任何经验。
-
实现方式:
-
通过设置
alpha=0,确保了 P(i)∝(priorityi)α 中的优先级指数为 0,即所有经验的采样概率 P(i) 相同。 -
通过设置
beta=0,确保了重要性采样权重 wi∝(1/(size⋅P(i)))β 中的权重指数为 0,即所有经验的权重 wi=1,无需修正。
-
总结
这种配置是在功能强大的 PER 基础设施上实现简单均匀采样的工程策略。开发者避免了为均匀采样编写一套单独的代码和数据结构,而是通过调整参数,让同一个类满足两种需求。
这段代码的核心作用是初始化一个高性能、支持优先级的在线经验回放缓冲区,用于存储智能体(Agent)与环境交互产生的数据,并为后续的强化学习训练提供数据
📝 使用方法示例
下面举个简单的例子。
1. 初始化
首先是您提供的初始化代码:
Python
import torch
from tensordict.nn import TensorDict
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.data.replay_buffers.transforms import MultiStepTransform
# 配置参数(假设来自 cfg 文件)
N_STEP = 3
GAMMA = 0.99
BUFFER_SIZE = 1_000_000
BATCH_SIZE = 256
ALPHA = 0.6 # 启用 PER
BETA_INIT = 0.4
PRIORITY_KEY = "_priority"
online_rb = TensorDictPrioritizedReplayBuffer(
storage=LazyTensorStorage(max_size=BUFFER_SIZE, device="cpu"),
alpha=ALPHA,
beta=BETA_INIT,
eps=1e-6,
priority_key=PRIORITY_KEY,
transform=MultiStepTransform(n_steps=N_STEP, gamma=GAMMA),
pin_memory=True,
prefetch=4,
batch_size=BATCH_SIZE,
)
2. 添加经验(Collect/Add Transitions)
在智能体与环境交互的循环中,将经验存入缓冲区。
Python
# 假设 obs, action, reward, next_obs, done 是从环境中采样得到的 TensorDict
# 注意:经验必须是 TensorDict 格式
# 步骤 t=0 的经验
td_experience_0 = TensorDict({
"observation": torch.randn(84, 84),
"action": torch.tensor(1),
"reward": torch.tensor(0.0),
"done": torch.tensor(False),
"next": {
"observation": torch.randn(84, 84),
"reward": torch.tensor(0.0),
"done": torch.tensor(False),
}
}, batch_size=[])
# 默认优先级 (当 alpha > 0 时, 缓冲区会给新经验设置一个高优先级)
online_rb.add(td_experience_0)
3. 采样数据(Sample Data for Training)
在训练步骤中,从缓冲区中采样一批数据用于模型更新。
Python
# 采样数据
# 缓冲区会根据优先级 (alpha) 和批大小 (batch_size) 采样数据
sampled_batch = online_rb.sample()
# 此时 sampled_batch 中已经包含了:
# 1. N 步回报 (G^(n)),由 MultiStepTransform 实时计算得到。
# 2. 重要性采样权重 (IS weights) W_i,存储在 "_is_weights" 键中(因为 beta > 0)。
# 3. 经验的索引,存储在 "_index" 键中,用于后续更新优先级。
print(f"采样批次形状: {sampled_batch.batch_size}")
print(f"IS 权重存在: {PRIORITY_KEY in sampled_batch.keys()}")
4. 更新优先级(Update Priorities - PER 核心)
在计算 TD 误差 δ 并完成模型更新后,我们需要根据误差的大小来更新被采样经验的优先级。
Python
# 假设 td_errors 是计算得到的 TD 误差的绝对值
td_errors = torch.rand(BATCH_SIZE) # 示例误差
# 使用采样批次的索引和新的优先级来更新缓冲区
online_rb.update_priority(
index=sampled_batch["_index"],
priority=td_errors ** ALPHA # 优先级 P = |TD_Error|^alpha
)
⚙️ TensorDictPrioritizedReplayBuffer 参数详解
1. 存储机制 (Storage)
| 参数 | 值 | 含义及作用 |
storage |
LazyTensorStorage(max_size=BUFFER_SIZE, device="cpu") |
定义了经验的实际存储方式和容量。 |
LazyTensorStorage |
是一种懒加载存储机制,它只在需要时才将数据加载到内存,对于存储大型数据(如高分辨率图像)非常高效。 | |
max_size=BUFFER_SIZE |
设置缓冲区可以存储的最大经验数量。当达到最大容量后,新的经验会替换掉最旧的经验(FIFO )。 | |
device="cpu" |
指定经验数据存储在 CPU 内存中,这能释放宝贵的 GPU 显存用于模型计算。 |
2. 优先级与采样策略 (PER Parameters)
| 参数 | 值 | 含义及作用 |
alpha ($\alpha$) |
ALPHA |
优先级指数。控制 TD 误差(优先级)对采样概率的影响程度。$\alpha=0$ 表示均匀采样,$\alpha>0$ 则误差越大的经验越容易被采样。 |
beta ($\beta$) |
BETA_INIT |
重要性采样 (IS) 权重指数。用于计算重要性采样权重,纠正 PER 带来的数据分布偏差。$\beta=0$ 表示不使用 IS 权重修正。$\beta$ 通常从一个较小的值开始,并随着训练逐步增加到 1。 |
eps ($\epsilon$) |
1e-6 |
一个小的正数,加到所有经验的优先级上,用于防止优先级为零的经验永远无法被采样,确保所有经验至少有微小的被学习机会。 |
priority_key |
PRIORITY_KEY |
指定在 TensorDict 数据中存储该经验优先级数值所使用的键名(如 "_priority")。 |
3. 数据预处理与转换 (Transform)
| 参数 | 值 | 含义及作用 |
transform |
MultiStepTransform(n_steps=N_STEP, gamma=GAMMA) |
定义了对从缓冲区采样的经验进行实时转换的流水线。 |
MultiStepTransform |
实现了 $N$-步回报计算。它会根据 n_steps 从缓冲区中提取连续的 $N$ 个经验,然后计算 $N$-步回报 $G^{(n)}$:$$G^{(n)} = \sum_{k=0}^{n-1} \gamma^k R_{t+k+1} + \gamma^n V(S_{t+n})$$这个转换极大地简化了 $N$-步 Q-learning 或 Sarsa 等算法的实现。 |
|
n_steps=N_STEP |
$N$ 步回报中 $N$ 的值。 | |
gamma=GAMMA |
折扣因子 $\gamma$ 的值。 |
4. 性能优化与批处理 (Performance & Batching)
| 参数 | 值 | 含义及作用 |
pin_memory |
True |
内存锁定。在 PyTorch 中,将 CPU 上的数据锁定到内存,可以显著加快数据从 CPU 到 GPU 的传输速度。 |
prefetch |
4 |
预取批次数量。在 GPU 忙于模型计算时,缓冲区会提前在后台加载并准备好 4 个批次(Batch)的数据。这减少了 GPU 的等待时间,提高了训练吞吐量。 |
batch_size |
BATCH_SIZE |
定义了每次从缓冲区中采样多少个经验用于训练。 |
下面给出一个更新例子:
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
# 1. 初始化缓冲池
# alpha: 决定优先级的程度 (0 为均匀采样, 1 为完全按优先级)
# beta: 重要性采样权重 (用于修正偏置)
# storage: 指定存储方式,LazyTensorStorage 可以在初次存入数据时自动推断 shape
rb = TensorDictPrioritizedReplayBuffer(
alpha=0.7,
beta=0.5,
storage=LazyTensorStorage(max_size=1000),
priority_key="td_error", # 默认查找此键来更新优先级
batch_size=4 # 每次采样的大小
)
# 2. 准备模拟数据 (假设这是一次环境交互产生的 transition)
# TensorDict 就像一个带 Batch 维度的字典
data = TensorDict({
"observation": torch.randn(10, 4), # 10 条数据,每条维度为 4
"action": torch.randint(0, 2, (10, 1)),
"next_observation": torch.randn(10, 4),
"reward": torch.rand(10, 1),
"done": torch.zeros(10, 1, dtype=torch.bool),
}, batch_size=[10])
# 3. 将数据添加到缓冲池
# 返回的 indices 是数据在池中的位置索引
indices = rb.extend(data)
print(f"存入数据后的缓冲池长度: {len(rb)}")
# 4. 从缓冲池采样
# samples 会包含原始数据,以及 PER 自动生成的 "index" 和 "sample_log_weight" (用于重要性采样)
samples = rb.sample()
print(f"采样数据的键: {samples.keys()}")
print(f"采样索引: {samples['index']}")
# 5. 模拟计算 Loss 并更新优先级
# 在实际 RL 中,你会计算 TD-error,然后将其写回到 TensorDict 中
# 假设我们计算出新的 TD-error 如下:
td_error = torch.tensor([0.1, 0.9, 0.5, 0.2])
samples.set("td_error", td_error)
# 使用 update_tensordict_priority 一键更新
# 它会自动利用 samples 中的 "index" 和 "td_error" 来更新线段树
rb.update_tensordict_priority(samples)
print("优先级已更新!下次采样时,td_error 较大的样本被抽中的概率会更高。")
输出结果:
存入数据后的缓冲池长度: 10
采样数据的键: _StringKeys(dict_keys(['observation', 'action', 'next_observation', 'reward', 'done', 'index', '_weight']))
采样索引: tensor([4, 9, 8, 9])
优先级已更新!下次采样时,td_error 较大的样本被抽中的概率会更高
更多推荐
所有评论(0)