Dueling DQN 算法实战
$\text{Dueling DQN}$ 算法主要是改了网络结构,其他地方跟 $\text{DQN}$ 是一模一样的,如代码清单 1 所示。
class DuelingQNetwork(nn.Module):
def __init__(self, state_dim, action_dim,hidden_dim=128):
super(DuelingQNetwork, self).__init__()
# 隐藏层
self.hidden_layer = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU()
)
# 优势层
self.advantage_layer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
# 价值层
self.value_layer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, state):
x = self.hidden_layer(state)
advantage = self.advantage_layer(x)
value = self.value_layer(x)
return value + advantage - advantage.mean() # Q(s,a) = V(s) + A(s,a) - mean(A(s,a))
最后我们展示一下它在 $\text{CartPole}$ 环境下的训练结果,如图 1 所示,完整的代码同样可以参考本书的代码仓库。
图 1 $\:$ $\text{CartPole}$ 环境 $\text{Dueling DQN}$ 算法训练曲线
由于环境比较简单,暂时还看不出来 $\text{Dueling DQN}$ 算法的优势,但是在复杂的环境下,比如 $\text{Atari}$ 游戏中,$\text{Dueling DQN}$ 算法的效果就会比 $\text{DQN}$ 算法好很多,读者可以在 $\text{JoyRL}$ 仓库中找到更复杂环境下的训练结果便于更好地进行对比。