添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接

【导语】:在深度强化学习第四篇中,讲了Policy Gradient的理论。通过最终推导得到的公式,本文用PyTorch简单实现以下,并且尽可能搞清楚torch.distribution的使用方法。代码参考了LeeDeepRl-Notes中的实现。

1. 复习

\[\theta \leftarrow \theta+\eta \nabla \bar{R_\theta} \\\nabla \bar{R_\theta}=\frac{1}{N}\sum^N_{n=1}\sum^{T_n}_{t=1}R(\tau^n)\nabla log p_\theta(a_t^n|s_t^n) \]

\(\theta\) 代表模型的参数,第一行公式代表了模型进行更新的方法, \(\eta\) 代表的是学习率。

第二行是推导得到的,和CrossEntropy可以对照着理解记忆。

2. Torch.Distributions

distributions包主要是实现了参数化的概率分布和采样函数。参数化是为了让模型能够具有反向传播的能力,这样才可以用随机梯度下降的方法来进行优化。随机采样的话没办法直接反向传播,有两个方法,REINFORCE和pathwise derivative estimator。

Torch中提供两个方法,sample()和log_prob(),就可以实现REINFORCE

\[\Delta \theta=\alpha r \frac{\partial \log p\left(a \mid \pi^{\theta}(s)\right)}{\partial \theta} \]

\(\theta\) 是模型参数, \(\alpha\) 代表的是学习率,r代表reward, \(p\left(a \mid \pi^{\theta}(s)\right)\) 代表在状态s下,使用策略 \(\pi^{\theta}\) 采取a动作的概率。

2.1 REINFORCE

实现的时候,会先从网络输出构造一个分布,然后从分布中采样一个action,将action作用于环境,然后使用log_prob()函数来构建一个损失函数,代码如下(PyTorch官方提供):

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

对照一下,这个-m.log_prob(action)应该对应上述公式:\(\log p\left(a \mid \pi^{\theta}(s)\right)\), 加负号的原因是,在公式中应该是实现的梯度上升算法,而loss一般使用随机梯度下降的,所以加个负号保持一致性。

2.2 PathWise Derivative Estimator

这是一种重参数化技巧,主要是通过调用rsample()函数来实现的,参数化随机变量可以通过无参数随机变量的参数化确定性函数来构造。参数化以后,采样过程就变得可微分了,也就支持了网络的后向传播。实现如下(PyTorch官方实现):

params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()

这样的话,可以直接对-reward使用随机梯度下降,因为rsample后可微分,可以后向传播。

3. 源码

主要看agent对象的实现:

class PolicyGradient:
    def __init__(self, state_dim, device='cpu', gamma=0.99, lr=0.01, batch_size=5):
        self.gamma = gamma
        self.policy_net = FCN(state_dim)
        self.optimizer = torch.optim.RMSprop(
            self.policy_net.parameters(), lr=lr)
        self.batch_size = batch_size
    def choose_action(self, state):
        state = torch.from_numpy(state).float()
        state = Variable(state)
        probs = self.policy_net(state)
        m = Bernoulli(probs)
        action = m.sample()
        action = action.data.numpy().astype(int)[0]  # 转为标量
        return action
    def update(self, reward_pool, state_pool, action_pool):
        # Discount reward
        running_add = 0 # 就是那个有discount的公式
        for i in reversed(range(len(reward_pool))): # 倒数
            if reward_pool[i] == 0:
                running_add = 0
            else:
                running_add = running_add * self.gamma + reward_pool[i]
                reward_pool[i] = running_add
        # 得到G
        # Normalize reward
        reward_mean = np.mean(reward_pool)
        reward_std = np.std(reward_pool)
        for i in range(len(reward_pool)):
            reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std
        # 归一化
        # Gradient Desent
        self.optimizer.zero_grad()
        for i in range(len(reward_pool)): # 从前往后
            state = state_pool[i] 
            action = Variable(torch.FloatTensor([action_pool[i]]))
            reward = reward_pool[i]
            state = Variable(torch.from_numpy(state).float())
            probs = self.policy_net(state)
            m = Bernoulli(probs)
            # Negtive score function x reward
            loss = -m.log_prob(action) * reward # 核心
            # print(loss)
            loss.backward()
        self.optimizer.step()
    def save_model(self, path):
        torch.save(self.policy_net.state_dict(), path)
    def load_model(self, path):
        self.policy_net.load_state_dict(torch.load(path))

可以看到核心实现是以下几句:

state = Variable(torch.from_numpy(state).float())
probs = self.policy_net(state)
m = Bernoulli(probs)
# Negtive score function x reward
loss = -m.log_prob(action) * reward # 核心
# print(loss)
loss.backward()

这里采用的是伯努利分布,二项分布,举个例子:

Example::
        >>> m = Bernoulli(torch.tensor([0.3]))
        >>> m.sample()  # 30% chance 1; 70% chance 0
        tensor([ 0.])

采样结果是0或者1,1对应的概率是p,0对应概率是1-p。

为神马要用这个伯努利分布呢?因为这个这个问题是CartPole-v0 ,其动作空间只有0或1,所以这里采用了Bernoulli,其他情况要使用不同的分布才能满足要求。

得到了采样结果以后,就是用了第二节提到的REINFORCE的方法计算loss,进行loss反向传播。

4. 总结

简单介绍了以下如何使用,但并没有深究背后的原理,这个系列会继续更新,同时我也会继续加强我的数学功底。

代码改变世界