PER DQN in Practice
Pseudocode
The core of the $\text{PER DQN}$ algorithm looks simple: it merely replaces ordinary experience replay with prioritized experience replay. The implementation, however, is somewhat more involved, because we need to build a $\text{SumTree}$ structure and perform a few extra operations when updating the model. We therefore start from the pseudocode, shown in Figure $\text{8-7}$.
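For reference, the quantities that the code below has to compute follow the standard prioritized experience replay formulation (the symbols correspond to the parameters alpha, beta and epsilon that appear in the code): each transition $i$ is stored with a priority $p_i$ derived from its TD error $\delta_i$, sampled with probability $P(i)$, and weighted in the loss by an importance-sampling weight $w_i$ that is normalized by its maximum over the batch,

$$p_i = \left(|\delta_i| + \epsilon\right)^{\alpha}, \qquad P(i) = \frac{p_i}{\sum_k p_k}, \qquad w_i = \frac{\left(N \cdot P(i)\right)^{-\beta}}{\max_j \left(N \cdot P(j)\right)^{-\beta}},$$

where $N$ is the number of transitions in the buffer and $\beta$ is annealed towards $1$ during training.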
The SumTree Structure
As shown in Code Listing $\text{8-6}$, we first implement the $\text{SumTree}$ structure.
import numpy as np

class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1) # size of the tree; the number of leaf nodes equals capacity
        self.data = np.zeros(capacity, dtype=object) # transitions stored alongside the leaf nodes
        self.data_pointer = 0
    def add(self, priority, data):
        '''Add a sample to the tree with the given priority
        '''
        tree_idx = self.data_pointer + self.capacity - 1 # index of the corresponding leaf node
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer += 1
        if self.data_pointer >= self.capacity: # overwrite the oldest sample once the buffer is full
            self.data_pointer = 0
    def update(self, tree_idx, priority):
        '''Update the priority of a leaf node and propagate the change to its ancestors
        '''
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change
    def get_leaf(self, v):
        '''Find the leaf node corresponding to the given value v
        '''
        parent_idx = 0
        while True:
            left_child_idx = 2 * parent_idx + 1
            right_child_idx = left_child_idx + 1
            if left_child_idx >= len(self.tree): # reached a leaf node
                leaf_idx = parent_idx
                break
            else:
                if v <= self.tree[left_child_idx]:
                    parent_idx = left_child_idx
                else:
                    v -= self.tree[left_child_idx]
                    parent_idx = right_child_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]
    @property
    def max_priority(self):
        '''Return the maximum priority among the leaf nodes
        '''
        return self.tree[-self.capacity:].max()
    @property
    def total_priority(self):
        '''Return the sum of all priorities, i.e., the value of the root node
        '''
        return self.tree[0]
Here, in addition to the array tree that stores the value of every node, we define a data array to hold the samples attached to the leaf nodes. The add function appends a sample to a leaf node and updates the priorities of its ancestors; the update function changes the priority of a leaf node and propagates the change up the tree; the get_leaf function returns the leaf node whose priority interval contains a given value; and the max_priority and total_priority properties return the largest leaf priority and the sum of all priorities (the value of the root node), respectively.
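To make the mechanics concrete, here is a minimal usage sketch with dummy transitions (the tuples and the capacity of 4 are purely illustrative), showing how added priorities accumulate at the root and how a value drawn from $[0, \text{total\_priority}]$ is mapped back to a leaf:

tree = SumTree(capacity=4)
tree.add(1.0, ("s0", "a0", 0.0)) # leaf priorities are now [1.0, 0, 0, 0]
tree.add(3.0, ("s1", "a1", 1.0)) # leaf priorities are now [1.0, 3.0, 0, 0]
print(tree.total_priority) # 4.0, the value of the root node
print(tree.max_priority) # 3.0
# values in [0, 1.0] fall into the first leaf, values in (1.0, 4.0] into the second
leaf_idx, priority, data = tree.get_leaf(2.5)
print(priority, data) # 3.0 ('s1', 'a1', 1.0)
# lowering a leaf's priority propagates the change up to the root
tree.update(leaf_idx, 0.5)
print(tree.total_priority) # 1.5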
Prioritized Experience Replay
Building on the $\text{SumTree}$ structure, we can combine prioritized sampling with importance sampling to implement the prioritized replay buffer, as shown in Code Listing $\text{8-7}$.
class ReplayBuffer:
    def __init__(self, cfg):
        self.capacity = cfg.buffer_size
        self.alpha = cfg.per_alpha
        self.beta = cfg.per_beta
        self.beta_increment_per_sampling = cfg.per_beta_increment_per_sampling
        self.epsilon = cfg.per_epsilon
        self.tree = SumTree(self.capacity)
        self.count = 0 # number of transitions currently stored
    def push(self, transition):
        # new samples enter with the current maximum priority so that they are replayed at least once
        max_prio = self.tree.max_priority
        if max_prio == 0:
            max_prio = 1.0
        self.tree.add(max_prio, transition)
        self.count = min(self.count + 1, self.capacity)
    def sample(self, batch_size):
        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling) # anneal beta towards 1
        minibatch = []
        idxs = []
        segment = self.tree.total_priority / batch_size # split the total priority into batch_size segments
        priorities = []
        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(s)
            minibatch.append(data)
            idxs.append(idx)
            priorities.append(p)
        sampling_probabilities = np.array(priorities) / self.tree.total_priority
        is_weight = np.power(self.tree.capacity * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max() # normalize the importance-sampling weights
        batch = list(zip(*minibatch))
        return tuple(map(lambda x: np.array(x), batch)), idxs, is_weight
    def update_priorities(self, idxs, priorities):
        for idx, priority in zip(idxs, priorities):
            self.tree.update(idx, (np.abs(priority) + self.epsilon) ** self.alpha)
    def __len__(self):
        return self.count
As we can see, the core of prioritized experience replay is the SumTree, which supports adding, updating, and sampling in $O(\log N)$ time. In practice, we can set the replay capacity to $10^6$, $\alpha$ to $0.6$, $\epsilon$ to $0.01$, $\beta$ to $0.4$, and $\beta_{\text{step}}$ to $0.0001$. Alternatively, we can implement prioritized experience replay with a Python queue, which is more concise in form, reduces the for-loop work during sampling, and is therefore more efficient, as shown in Code Listing $\text{8-8}$.
from collections import deque

class PrioritizedReplayBufferQue:
    def __init__(self, cfg):
        self.capacity = cfg.buffer_size
        self.alpha = cfg.per_alpha # priority exponent; larger values put more weight on high-priority samples
        self.epsilon = cfg.per_epsilon # small constant that keeps priorities strictly positive
        self.beta = cfg.per_beta # importance-sampling exponent
        self.beta_annealing = cfg.per_beta_annealing # increment of beta per sampling step
        self.buffer = deque(maxlen=self.capacity)
        self.priorities = deque(maxlen=self.capacity)
        self.count = 0 # number of transitions currently stored
        self.max_priority = 1.0
    def push(self, exps):
        self.buffer.append(exps)
        self.priorities.append(max(self.priorities, default=self.max_priority))
        self.count = min(self.count + 1, self.capacity)
    def sample(self, batch_size):
        self.beta = min(1.0, self.beta + self.beta_annealing) # anneal beta towards 1
        priorities = np.array(self.priorities)
        probs = priorities / sum(priorities)
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        weights = (self.count * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        exps = [self.buffer[i] for i in indices]
        return zip(*exps), indices, weights
    def update_priorities(self, indices, priorities):
        priorities = np.abs(priorities)
        priorities = (priorities + self.epsilon) ** self.alpha
        priorities = np.minimum(priorities, self.max_priority).flatten()
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority
    def __len__(self):
        return self.count
Finally, we can combine prioritized experience replay with $\text{DQN}$ to obtain a $\text{DQN}$ algorithm with prioritized replay, and show its training results on the $\text{CartPole}$ environment, as shown in Figure $\text{8-8}$.
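To illustrate how the buffer plugs into the learning loop, the sketch below outlines one PER DQN update step under a few assumptions: the transitions stored in the buffer are (state, action, reward, next_state, done) tuples, and policy_net, target_net, optimizer and cfg (with batch_size and gamma attributes) are placeholders for objects constructed elsewhere. It is not the exact implementation used in the full code, only the shape of the step: sample by priority, weight the squared TD errors by the importance-sampling weights, and write the new absolute TD errors back as priorities.

import torch

def update(policy_net, target_net, optimizer, buffer, cfg):
    '''One PER DQN learning step: sample by priority, weight the loss, refresh priorities
    '''
    (states, actions, rewards, next_states, dones), idxs, is_weight = buffer.sample(cfg.batch_size)
    states = torch.tensor(states, dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    next_states = torch.tensor(next_states, dtype=torch.float32)
    dones = torch.tensor(dones, dtype=torch.float32)
    is_weight = torch.tensor(is_weight, dtype=torch.float32)
    q_values = policy_net(states).gather(1, actions).squeeze(1) # Q(s, a) of the sampled actions
    with torch.no_grad():
        next_q_values = target_net(next_states).max(1)[0] # max_a' Q_target(s', a')
        targets = rewards + cfg.gamma * next_q_values * (1.0 - dones)
    td_errors = targets - q_values
    # importance-sampling weights correct the bias introduced by non-uniform sampling
    loss = (is_weight * td_errors.pow(2)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # the new absolute TD errors become the priorities of the sampled transitions
    buffer.update_priorities(idxs, td_errors.abs().detach().numpy())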