技术前沿
作者:Steeve Huang
编译整理:萝卜兔
最近,Graph Neural Network(GNN)在很多领域日益普及,包括社交网络、知识图谱、推荐系统甚至于生命科学。GNN在对节点关系建模方面表现十分突出,使得相关的研究领域取得了一定突破。本文旨在对GNN做一个简单的介绍,并介绍两种前沿算法,DeepWalk和GraphSage。
Graph
在学习GNN之前,先让我们了解一下什么是Graph。在计算机科学中,graph是一种数据结构,由两部分组成,顶点和边。图G可以通过顶点集V和边集E来描述:
根据顶点之间是否有方向,边可以分为无向和有向。
有向图
顶点又称节点,本文中两者可以互换。
Graph Neural Network
GNN是直接在图数据结构上运行的神经网络。GNN的典型应用便是节点分类。图中的每个节点都有一个标签,我们希望不需要标注数据,可以预测新的节点标签。本节将讲解论文《The graph neural network model》中的GNN算法,算得上第一个GNN。
在节点分类问题中,每个节点v的特征用xv表示,并且和标签tv相关联。给定部分标记的图G,目标是利用这些标记的节点来预测未标记节点的标签。网络学会用d维向量(状态)hv表示每个节点,其中包含其邻域信息。
其中,xco[v]表示与v连接的边的特征,hne[v]表示v的相邻节点嵌入特征,xne[v]表示v的相邻节点的特征。函数f是将这些输入投影到d维空间的传递函数。由于我们正在寻找hv的唯一解,可以应用Banach不动点定理并将上述等式重写为迭代更新过程。
H和X分别表示所有h和x的连接。通过将状态hv以及特征xv传递给输出函数g来计算GNN的输出。
f和这里的g都可以解释为前馈全连接神经网络。 L1损失可以直接表述如下:
再通过梯度下降优化。上述的GNN算法有三个限制:
1、如果放宽“固定点”的假设,可以利用多层感知机来学习更稳定的表示,并删除迭代更新过程。这是因为,在该提议中,不同的迭代使用传递函数f的相同参数,而不同MLP层中的不同参数允许分层特征提取。
2、它不能处理边信息(例如知识图中的不同边可能表示节点之间的不同关系)。
3、固定点可以阻止节点分布的多样化,因此可能不适合学习表示节点。
当然,已经有几种GNN的变体来解决上述问题,我在这里不展开讲解了。
DeepWalk
DeepWalk是以无监督的方式学习node embedding的算法。它的训练过程非常类似word embedding。动机是图表中的节点和语料库中的单词的分布遵循幂定律,如下图所示:
算法包括两个步骤:
1、在图中的节点上进行随机游走以生成节点序列;
2、使用skip-gram,根据步骤1中生成的节点序列学习每个节点的嵌入。
在随机游走的每个时间步骤,从前一节点的邻居统一采样下一个节点。然后将每个序列截短为长度为2|w|+1的子序列,其中w表示skip-gram中的窗口大小。
在提出DeepWalk的论文中,分层softmax用于解决由于节点数量庞大而导致的softmax计算成本高昂的问题。为了计算每个单独输出元素的softmax值,我们必须计算所有元素k的所有exk。
因此,原始softmax的计算时间为O(|V|),其中V表示图中的顶点集。
分层softmax利用二叉树来处理问题。在这个二叉树中,所有叶子(上图中的v1,v2,...,v8)都是图中的顶点。在每个内部节点中,有一个二元分类器来决定选择哪条路径。为了计算给定顶点vk的概率,可以简单地计算沿着从根节点到离开vk的路径中的每个子路径的概率。由于每个节点的子概率为1,因此所有顶点的概率之和等于1的特性仍然保持在分层softmax中。现在,元素的计算时间减少到O(log|V|),因为二叉树的最长路径由O(log|n|)限定,其中是n叶子的数量。
分层Softmax
在训练DeepWalk GNN之后,模型已经学习了每个节点的良好表示,如下图所示。不同的颜色表示输入图中的不同标签。我们可以看到,在输出图形(嵌入2维)中,具有相同标签的节点聚集在一起,而具有不同标签的大多数节点被正确分开。
然而,DeepWalk的主要问题是缺乏泛化能力。每当有新节点加入时,它必须重新训练模型以表示该节点。因此,这种GNN不适用于图中节点不断变化的动态图。
GraphSage
GraphSage提供解决上述问题的方案,以归纳方式学习每个节点的嵌入。具体而言,每个节点由其邻域的聚合表示。因此,即使在训练时间内看不到的新节点出现在图中,它仍然可以由其相邻节点正确地表示。下面显示了GraphSage的算法。
外层循环表示更新迭代次数,而hvk表示更新迭代k时节点v的特征。在每次更新迭代时,基于聚合函数,前一次迭代中v和v邻域的特征以及权重矩阵Wk来更新hvk。本文提出了三种聚合函数:
1. Mean aggregator
平均聚合器获取节点及其所有邻域的特征的平均值。
与原始方程相比,它删除了上述伪代码中第5行的连接操作。此操作可以被视为“跳过连接”,本文稍后将证明可以在很大程度上提高模型的性能。
2. LSTM aggregator
由于图中的节点没有任何顺序,因此它们通过置换这些节点来随机分配顺序。
3. Pooling aggregator
此运算符在相邻集上执行逐元素池化功能。下面显示了max-pooling示例:
可以用mean-pooling或任何其他对称池化函数替换。文章指出pooling aggregator执行最佳,而mean-pooling和max-pooling具有相似的性能。本文使用max-pooling作为默认聚合函数。
损失函数定义如下:
其中u和v共同出现在固定长度的随机游走中,而vn是不与u共同出现的负样本。这种损失函数鼓励具有类似嵌入的节点更接近,而那些相距很远的节点在投影空间中分离。通过这种方法,节点将获得越来越多关于其邻域的信息。
GraphSage通过聚合其附近的节点,可以为看不见的节点生成可表示的嵌入。它允许将节点嵌入应用于涉及动态图的域,其中图的结构不断变化。例如,Pinterest采用了GraphSage的扩展版本PinSage作为其内容发现系统的核心。
总结
我们已经学习了图形神经网络,DeepWalk和GraphSage的基础知识。 GNN在复杂图形结构建模中的强大功能确实令人惊讶。鉴于其有效性,我相信,在不久的将来,GNN将在人工智能的发展中发挥重要作用。
相关参考
https://towardsdatascience.com/a-gentle-introduction-to-graph-neural-network-basics-deepwalk-and-graphsage-db5d540d50b3
参考论文:
https://arxiv.org/pdf/1812.08434.pdf
http://www.perozzi.net/publications/14_kdd_deepwalk.pdf
https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.1015.7227&rep=rep1&type=pdf
>
领取专属 10元无门槛券
私享最新 技术干货