Efficient Context and Schema Fusion Networks for Multi-Domain Dialogue State Tracking 概要 问题动机 对于多域 DST,由于候选状态的数量和对话长度的增加,数据稀疏性问题 是一个主要障碍。
主要贡献 为了有效地编码对话上下文,本文利用 以前的对话状态 (预测的)和 当前的对话话语 作为 DST 的 输入 。 为了 考虑不同域槽之间的关系 ,利用了 涉及先验知识的模式图 。 提出了一种新的 上下文 和模式 融合网络,利用内部和外部的注意机制对对话上下文和模式图进行编码。经过多层注意网络后,利用每个 域 - 插槽
节点的最终表示来预测相应的值,涉及上下文和模式信息。对于值预测,应用槽门分类器来决定对话中是否提到了域槽,然后利用基于 RNN 的值解码器来生成相应的值。 问题定义 假设存在 M 个域,\mathcal{D}={d_1,d_2,...,d_M} , 每一个域 d d \in \mathcal{D} 的槽集合为 \mathcal{S}^d = {s_1^d,s2^d,...,s^d{|\mathcal{S}^d|} } 因此,完全可能存在 J = \sum_{m=1}^M |\mathcal{S}^{d_m}| 个可能的 域 - 槽 对 \mathcal{O} = {O_1,O_2,...,O_J} 。由于不同的域可能包含相同的槽,因此我们表示所有不同的 N 个插槽,如 S={s_1、s_2、···、S_N} ,其中 N≤J 。
对话可以正式表示为 其中 A_t 为第 t 轮的系统话语,U_t 为用户话语,B_t 表示相应的对话状态。A_t 和 U_t 是单词序列,而 Bt 是一组域槽值三元组,例如 , 值 v{tj} 是第 t 回合第 j 域槽对的单词序列。DST 的目标是根据对话历史,正确地预测每个域 - 插槽对的值。
之前的大多数论文都选择连接对话历史中的所有单词,[A_1、U_1、A_2、U_2、···、A_t、U_t] 作为输入 。然而,这可能会导致 计算时间的增加 。在本文中, 仅利用当前的对话回合 A_t、Ut 和之前的对话状态 B{t−1} 来预测新的状态 Bt 。在训练过程中,则是使用了 B{t−1} 的真实值,而之前预测的对话状态将被用于推理阶段。
Schema Graph: 为了考虑不同 domain-slot 对之间的关系,并且利用它们作为一个额外的输入(先验知识)来指导上下文编码,本文将它们表示为一个模式图 G=(V,E) ,具有节点集 V 和边缘集 E。如下图:
在图中,有 三种节点 分别表示 所有域 \mathcal{D} 、 槽\mathcal{S} 和 域 - 槽 对 \mathcal{O} ,即 V=D∪S∪O 。利用不同节点之间的 四种无向边来编码先验知识:
(d,d^{'}) : 不同域之间的边(s,d) : 域与槽之间的边,只有在 d\in \mathcal{D} 且 s\in \mathcal{S}^d 时有链接(d,o) 和 (s,o) : 如果 domain-slot 对 o \in \mathcal{O} 是由 d\in \mathcal{D} 和 s\in \mathcal{S} 组成,那么就有两条边分别从 d 和 s 连接(虚线)到这个 域 - 插槽 节点(s,s') : 如果两个不同插槽(s \in \mathcal{S} 且 s^{'} \in \mathcal{S} )的 候选值重叠,它们之间也有一个边缘,例如,目的地和离开、离开和到达。模型 如上图,模型由 输入嵌入 、 上下文 - 模式融合网络 和状态预测模块 组成。下面将分别介绍它们:
Input Embeddings 输入分为 3 个部分,分别为 当前回合对话语句
、 先前回合的对话状态
, 模式图
。下面分别介绍其表示:
Dialogue Utterance 输入表示 X_t = [CLS] \oplus A_t \oplus; \oplus U_t \oplus [SEP] ,由于 [CLS] 被设计为捕获序列嵌入,因此它与其他 token 具有不同的段类型。X_t 的输入嵌入是 token 嵌入、 段嵌入 和位置嵌入 的值之和,与一般 Bert 的输入并无二致。 Previous Dialog State 输入表示 B{t-1} = [CLS] \oplus R^1{t-1} \oplus \cdots R^K{t-1} , K 表示的是 B{t-1} 中的元组数量,每一个元组由 d-s-v 三值来确定 R = d \oplus - \oplus s \oplus - \oplus v . Schema Graph 如前所述,模式图 G 由 M 个 域节点 、N 个 槽节点 和 J 个 域 - 槽节点 组成。这些节点的排列方式为: G = d_1 \oplus \cdots \oplus d_M \oplus s_1 \oplus \cdots \oplus s_N \oplus o_1 \oplus \cdots \oplus o_J 每个节点的嵌入都是通过平均相应域 / 插槽 / 域插槽中的 token 嵌入来初始化的。在图中省略了位置嵌入。图的边被表示为一个邻接矩阵
Context and Schema Fusion Network(CSFN) 利用上下文和模式融合网络 (CSFN) 逐层计算 Xt、B{t−1} 和 G 中的 token 或节点的隐藏状态。该网络是由 L 层 上下文模式感知自注意力层堆叠而成,下面将详细介绍其结构:
CSFN 详解:
C = MultiHead{\Theta} (Y,Z) C = GraphMultiHead_{\Theta} (Y,Z,A) 上下文与模式感知编码 每一层 CSFN 都由内部和外部的注意力组成,以包含不同类型的输入。第 i 层上模式图 G 的隐藏状态 H_i^G 更新如下: 简单解释一下过程:首先通过图多头注意力模块获得图中各节点之间的注意力矩阵 I{GG} , 在通过多头注意力模块获得输入语句与图中各节点的注意力矩阵 E{GX} , 以及 先前对话状态与图中各节点的注意力矩阵 E{GB} , 最终的模式图注意力矩阵 C{G} 为 Hi^G、I{GG}、E{GX}、E{GB} 之和后通过 LayerNorm 后的结果。 State Prediction 状态预测的目标是产生下一个对话状态 B_t ,它分为两个阶段:
Slot-gate Classification: 首先对每个域插槽节点应用一个插槽门分类器。分类器在 {NONE,DONTCARE,PTR} 中做出决策,其中 NONE 表示在这个回合没有提到域槽对,DONTCARE意味着用户可以接受这个槽的任何值,MARKDOWN_HASHccf95d5d6208e6821c61433d43848f16MARKDOWNHASH表示槽应该用一个值来处理, 其输入为模式图的最终隐藏状态: P{tj}^{gate} = softmax(FNN(H^G{L,M+N+j})) 该分类器的损失函数为:
\mathcal{L}{gate} = - \sum^T{t-1} \sum{j=1}^J \log (P^{gete}{tj} \cdot (\color{green}{y{tj}^{gate}})^T ) RNN-based Value Decoder: 对于标记为 MARKDOWN_HASHccf95d5d6208e6821c61433d43848f16MARKDOWNHASH 的域插槽对,进一步引入了一个基于 RNN 的值解码器来生成其值的 token 序列。本文用的是 GRU 来生成槽值。其隐藏状态 g{tj}^k \in R^{1\times d{model}} 通过输入一个词嵌入 e{tj}^k 迭代计算直到结束 token 被生成。初始隐藏状态为 g{tj}^0 = H{L,0}^{Xt} + H{L,0}^{B{t-1}} , 且 e{tj}^0 = H_{L,M+N+j}^G . 槽值生成器在第 k 步将隐藏状态转换为令牌词汇的概率分布,包括两部分:1)所有输入令牌的分布,2)输入词汇的分布。第一部分计算为: 第二部分计算为: 最终,使用软拷贝机制来获取所有候选令牌上的最终输出分布: 其损失函数为: 我们将两个模块联合训练,得到联合损失函数为:
实验 数据集 数据集采用 MultiWOZ 2.0 与 MultiWOZ 2.1
实验设置 CSFN 的隐藏向量维度 d_{model} = 400 ,注意力头数量 H = 4 , 所有嵌入维度都为 400。batch_size = 32 优化器选择 ADAM,learning-rate=1e-4
结果 如表所示,本文提出的 CSFN-DST 的性能可以优于除 SOM-DST 之外的其他模型。通过将我们的模式图与 SOM-DST 相结合,我们可以在开放词汇表设置中的 MultiWOZ2.0 和 2.1 上实现最先进的性能。此外,我们使用 BERT(伯特基不变的)的方法可以在预定义的基于本体的设置中获得与最佳系统相比的非常具竞争力的性能。当利用 BERT 时,我们使用 BERT 编码器初始化 CSFN 的所有参数,并使用 BERT 初始化令牌 / 位置嵌入。