添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
论文理解【Offline RL】——【DT】Decision Transformer: Reinforcement Learning via Sequence Modeling
已于 2023-02-06 15:38:13 修改
2022-12-23 04:01:24 阅读量 2.4k
  • 摘要:我们提出了一个将强化学习(RL)抽象为序列建模问题的框架。这使得我们能够利用 transformer 模型的简单性(simplicity)和可扩展性(scalability),以及 GPT-x 和 BERT 等语言建模方面的相关进展。特别地,我们提出了 Decision Transformer(DT),一个将 RL 问题转换为条件序列建模的架构。与之前基于值函数或计算策略梯度的 RL 方法不同,DT 只是通过 causally masked Transformer 来输出最优操作。 通过训练一个以期望 return、过去的状态和动作作为条件的自回归模型,DT 可以产生实现期望回报的未来行动 。尽管 DT 很简单,但它在 Atari、OpenAI gym 和 Key-to-Door 任务上媲美或超过了 SOTA 的 model-free offline RL Baseline 的性能。
  • Offline RL 是这样一种问题设定:Learner 可以获取由一批 episodes 或 transitions 构成的 固定交互数据集 ,要求 Learner 直接利用它训练得到一个好的策略,而且 禁止 Learner 和环境进行任何交互 ,示意图如下
    在这里插入图片描述
    关于 Offline RL 的详细介绍,请参考 Offline/Batch RL简介
  • Offline RL 是近年来很火的一个方向,下图显示了 2019 年以来该领域的重要工作,本文出现在 21 年,和同期的 TT 一样是最先 纯粹 使用 Transformer 模型解 Offline RL 问题的文章,可能也是最先把 Offline RL 当做序列建模问题来解的文章
    在这里插入图片描述
  • 本文使用的 Transformer 模型是 GPT,这是 Transformer 的 Decoder 部分,可以作为标准语言模型使用。所谓 “标准语言模型”,就是它会吃进去句尾的若干个词(token)然后预测下一个词是什么,这个过程可以反复进行从而实现文本生成的效果。GPT 相对传统语言模型的优势在于, GPT 模型结构中的 masked self attention 层利用 Q , K , V 矩阵在各个 token 间建立起了显式连接,从而解决了 RNN/LSTM 等传统序列模型的长跨度信息的遗忘问题 。详细说明可以参考 快速串联 RNN / LSTM / Attention / transformer / BERT / GPT

2. 本文方法

2.1 思想

2.2 方法

  • 标准设定下,offline 数据集中的一条轨迹形如
    t = t = t T r t

    下图给出了一个 return-to-go 的示意
    在这里插入图片描述
    这个 MDP 中每走一步会得到 -1 的 reward,agent 从随机位置出发按随机策略行动直到到达 goal 结束轨迹。每个圈边上的数字即是该轨迹中对应状态下的 return-to-go。在训练完成后的 generation 阶段,给定起点位置后,只须 在每一步决策时以训练数据中此位置收到的最大 return-to-go 作为条件选择动作,就能实现 offline 数据集中次优轨迹的拼接 。这里只是一个示意,作者提出的方法中对 return-to-go 条件的选取细节有所不同

  • 本文的设计的 DT 模型结构如下
    在这里插入图片描述
    从下往上看

    1. 处理后的轨迹按顺序输入,任意 R ^ t , s t , a t 使用相同的位置编码
    2. 另外这里嵌入层不是像传统 GPT 那样使用 nn.Embedding 配合 vocab_size 实现,而是直接用 nn.Linear 实现,这可能是因为很多 RL 环境状态、动作空间过大,或者直接就是连续的环境。另外,如果是图像形式的状态输入(如 Atria 环境),则用 nn.Conv2d 进行嵌入
    3. state token 嵌入之前首先要经过状态归一化,即先用数据集中所有 state 计算均值和标准差,然后将每个 state 减去均值除以标准差,使得所有 state 宏观上呈现 0 均值 1 标准差的分布
  • 接下来所谓的 “causal transformer” 其实就是 GPT 这种带有 mask self attention 结构的 Transformer 结构,不过 DT 里这个 mask 是按时刻 t 一组一组地遮盖的
  • 再上面的橘黄色块是一个就是线性层,把 GPT 结构得到的动态 token 嵌入向量转变为实际的 action。具体而言,如果环境的动作空间是离散的,这里就用线性层调整下维度然后加 softmax 优化交叉熵损失;如果动作空间是连续的,这里就用线性层直接输出动作然后优化 L2 损失
  • 上图中的红色虚线显示了推断阶段的 autoregress 过程,而训练阶段使用的是 teacher-forcing 。具体而言

    1. 训练时首先从 offline 数据集中采样一段连续 a t 计算损失来优化 DT
    2. 测试阶段, 人为选定一个目标的 return-to-go(比如 offline 数据集中最高的轨迹 return) ,从初始状态开始不断预测 action,配合环境交互实现自回归。 注意每一步交互得到真实 reward 后,就从初始设定的 return-to-go 中减去这个值,这样就自然地得到下一时刻的 return-to-go ,不断这样操作,DT 就能作为策略使用了
  • 下面给出伪代码
    在这里插入图片描述

3. 实验

  • 作者主要在 Atria 的离散环境和 gym 的经典连续控制环境中进行了测试,前者要求 long-term credit assignment 的能力,后者要求精细控制能力。另外还在一个 Key-To-Door 环境进行附加测试,该网格世界环境要求 agent 在必须第一个房间拿到钥匙并在第三个房间走到门才能得到 reward,奖励非常稀疏,是对 long-term credit assignment 能力要求非常高的特化环境。对比方法主要是一些 RL-based 的 TD-Learning 方法以及 BC,总体性能如下
    在这里插入图片描述
    可见 DT 媲美或超过过去的最好方法

3.1 Atari

  • Atari 环境的难点主要在于高维的视觉输入以及延迟 reward 导致的信用分配困难。具体而言,作者在各个测试环境首先训练一个 DQN,再从其 replay-buffer 中随机采样 1% 的轨迹(约50万 transition)作为 offline 数据集来训练。我们知道 GPT 这类 Transformer 模型由于 self attention 的计算量无法输入太长的序列,因此这里 DT 输入的长度依环境不同被限制为 50 或 30 ,四个环境下实验效果如下
    在这里插入图片描述
    这里的 BC 就是 DT 不加 return-to-go 条件,直接像普通 GPT 那样工作实现的。四个环境中三个表现不错

3.2 gym