转自 | 新智元
编辑:好困 小咸鱼
11月18日,谷歌联合DeepMind发布了TensorFlow GNN(图神经网络)。
目前,谷歌已经在诸如垃圾邮件检测、流量估计以及YouTube内容标签等环境中用上了这个库的早期版本。
为什么要用GNN?
图(Graph)是用于表示对象之间关联关系的一种抽象数据结构,使用节点/顶点(Node/Vertex)和边(Edge)进行描述:顶点表示对象,边表示对象之间的关系。
而在现实世界和工程系统中,图无处不在。
比如,一组物体、地点或人以及它们之间的联系通常可以用图形来描述。
通常,在机器学习问题中看到的数据是结构化的或关系型的,因此也可以用图来描述。
经过几十年的基础研究,GNN已经在很多的领域都取得了进展,如流量预测、谣言和假新闻检测、疾病传播建模、物理模拟以及理解分子为什么有气味。
图模拟不同类型数据之间的关系:网页(左)、社交联系(中)或分子(右)
通过GNN,就可以回答那些关于图的多种特征的问题。比如在图中观察到的各种不同的「形状」:图中的圆圈,可能代表子分子,也可能代表密切的社会关系。
在节点级的任务中,GNN可以对图的节点进行分类,并预测图中的分区和亲和力,类似于图像分类或分割。
在边级别的任务中,可以使用GNN来发现实体之间的连接,比如用GNN「修剪」图中的边,从而识别场景中对象的状态。
TF-GNN的结构
TF-GNN为在TensorFlow中实现GNN模型提供了构建模块。
除了建模API之外,TF-GNN还围绕着处理图数据的困难任务提供了大量的工具:基于Tensor的图数据结构,数据处理管道,以及一些供用户快速上手的示例模型。
组成工作流程的TF-GNN的各个部分
TF-GNN库的初始版本包含了许多实用程序和功能,包括:
使用示例
比如,使用TF-GNN Keras API建立一个模型,并根据用户观看的内容和喜欢的类型向其推荐电影。
通过使用ConvGNNBuilder方法来指定边缘和节点配置的类型,即对边缘使用WeightedSumConvolution。每次通过GNN时,都将通过Dense互连层来更新节点值:
import tensorflow as tf
import tensorflow_gnn as tfgnn
# Model hyper-parameters:
h_dims = {'user': 256, 'movie': 64, 'genre': 128}
# Model builder initialization:
gnn = tfgnn.keras.ConvGNNBuilder(
lambda edge_set_name: WeightedSumConvolution(),
lambda node_set_name: tfgnn.keras.layers.NextStateFromConcat(
tf.keras.layers.Dense(h_dims[node_set_name]))
)
# Two rounds of message passing to target node sets:
model = tf.keras.models.Sequential([
gnn.Convolve({'genre'}), # sends messages from movie to genre
gnn.Convolve({'user'}), # sends messages from movie and genre to users
tfgnn.keras.layers.Readout(node_set_name="user"),
tf.keras.layers.Dense(1)
])
此外,还可以在某些场景下让GNN使用一个更强大的自定义模型架构。
例如,指定某些电影或流派在推荐时拥有更多的权重。
那么,就可以通过自定义图卷积来生成一个更高级的GNN。
在下面的这段代码中,就用WeightedSumConvolution类来汇集边的值,并作为所有边的权重之和:
class WeightedSumConvolution(tf.keras.layers.Layer):
"""Weighted sum of source nodes states."""
def call(self, graph: tfgnn.GraphTensor,
edge_set_name: tfgnn.EdgeSetName) -> tfgnn.Field:
messages = tfgnn.broadcast_node_to_edges(
graph,
edge_set_name,
tfgnn.SOURCE,
feature_name=tfgnn.DEFAULT_STATE_NAME)
weights = graph.edge_sets[edge_set_name]['weight']
weighted_messages = tf.expand_dims(weights, -1) * messages
pooled_messages = tfgnn.pool_edges_to_node(
graph,
edge_set_name,
tfgnn.TARGET,
reduce_type='sum',
feature_value=weighted_messages)
return pooled_messages
尽管卷积是在只考虑源节点和目标节点的情况下编写的,但是TF-GNN确保了它的适用性,并且可以无缝地在异构图(具有各种类型的节点和边)上工作。
安装说明
这是目前安装tensorflow_gnn预览版的唯一方法。强烈建议使用虚拟环境。
$> git clone https://github.com/tensorflow/gnn.git tensorflow_gnn
TF-GNN需要用到TensorFlow 2.7中的一个功能:tf.ExtensionTypes。
$> pip install tensorflow
构建TF-GNN的源代码需要用到Bazel。
TF-GNN将使用GraphViz作为可视化工具。安装方法因操作系统而异,例如,在Ubuntu中:
$> sudo apt-get install graphviz graphviz-dev
$> cd tensorflow_gnn && python3 -m pip install
参考资料:
https://blog.tensorflow.org/2021/11/introducing-tensorflow-gnn.html?m=1
https://github.com/tensorflow/gnn
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有