Large Memory Layers with Product Keys
3 Learnable product key memories
我们考虑函数 $m:\mathbb{R}^{d} \rightarrow \mathbb{R}^{n}$ 它将作为神经网络中的一个层。m 的目的是在神经网络中提供大容量。
3.1 Memory design
高层结构: 我们 memory 整体结构如图 1 和图 2 所示。memory 由三个组件组成:查询网络、包含两组子键的键选择模块和值查找表。它首先计算与积键集进行比较的查询。对于每个积键,它计算一个分数并挑选分数最高的 k 个积键。然后,通过与所选键相关联的值的加权总和,使用 m(x)来产生输出。memory 所有参数都是可训练的,但每个输入只更新 k 个 memory slot。稀疏选择和参数更新使得训练和推理都非常有效。
查询生成:预处理网络: 函数 $q: x \mapsto q(x) \in \mathbb{R}^{d_{\mathrm{q}}}$ 简称查询网络,将 d 维输入映射到$d_q$维度的隐式空间。通常 q 是一个线性映射或者多层感知机,将维数从 d 减小到 512.由于 keys 是随机初始化的,因此它们相对均匀地占据空间。在查询网络的顶部添加 BN 层,有助于训练时 key 的收敛。我们在 4.5 节中的消融实验证实了这一见解。
标准键分配与加权: 设 q(x)表示查询网络,$\mathcal{T}_{k}$表示 top-k 运算符。给定键集 $\mathcal{K}=\left\{k_{1}, \ldots, k_{|\mathcal{K}|}\right\}$ 由$|\mathcal{K}|$个$d_q$维向量和一个输入 x 组成,我们使用查询$q(x)$选择前 k 个键来最大化内积:
$\mathcal{I}$表示 k 个最相似的键(其中相似度量采用的是内积),w 是表示与所选键关联的归一化分数的向量。所有这些操作都可以使用自动微分机制来实现,使我们的层可以在神经网络中的任何位置插入。
操作(2)和(3)仅取决于前 k 个索引,因此在计算上是有效的。相比之下,等式(1)的详尽比较对于大 memories 来说效率不高,因为它设计计算$|K|$个内积。为了规避这个问题,我们采用一组结构化的 keys 集合,称为 product keys。
product key set 被定义为两个向量 codebooks $C$和$C^\prime$的外积,与向量连接操作符有关。
这个笛卡尔积结构所引起的 keys 总数是$|\mathcal{K}|=|\mathcal{C}| \times\left|\mathcal{C}^{\prime}\right|$。集合 $C$ 和 $C^\prime$都包含一组维度为 $d_q/2$ 的子 keys。我们利用这种结构来有效地计算最接近的 keys $\mathcal{I} \in(1, \ldots, K)$。首先,我们将查询$q(x)$拆分为两个子查询 $q1$ 和 $q2$。然后,我们计算 $C$ 中最接近子查询 $q1$(相应 $q2$)的那 k 个子键(相应的 $C^\prime$):
可保证 K 中 k 个最相似的键的形式为$\left\{\left(c_{i}, c_{j}^{\prime}\right) \mid i \in \mathcal{I}_{\mathcal{C}}, j \in \mathcal{I}_{\mathcal{C}^{\prime}}\right\}$。图 2 显示了具有 key 选择过程的 product keys 示例。
本文作者 : HeoLis
原文链接 : https://ishero.net/Large%20Memory%20Layers%20with%20Product%20Keys.html
版权声明 : 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明出处!
学习、记录、分享、获得