前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >知识扩展---图神经网络GraphSAGE用于空间转录组的分子聚类

知识扩展---图神经网络GraphSAGE用于空间转录组的分子聚类

原创
作者头像
追风少年i
发布2025-02-05 15:20:55
发布2025-02-05 15:20:55
13400
代码可运行
举报
运行总次数:0
代码可运行

作者,Evil Genius

开工第一天,岁数又加一。

今天我们分享的内容是,GraphSAGE。

在昨天的文章文献分享---空间转录组学鉴定与肺纤维化远端肺重构相关的分子生态位失调(Xenium + HD)中,我们可以看到文中对空间转录组的分子聚类采用的是

从文章的信息可以获知,1、GraphSAGE是一种图神经网络;2、GraphSAGE,基于转录组数据的空间位置训练了一个图神经网络模型以聚合局部邻域信息并定义一个嵌入空间,该空间为数据集中的所有单个转录组提供了新的表示。然后,应用高斯混合模型在嵌入空间中聚类转录组,并识别生态位,使用共识方法将细胞分配到这些生态位中。

那么我们需要首先知道什么是图神经网络GraphSAGE。

GraphSAGE(Graph Sample and Aggregated)是一种用于图节点嵌入学习的图神经网络模型。它通过采样和聚合的方式,将邻居节点的信息聚合到目标节点上,从而学习节点的表示向量。

GraphSAGE的核心思想是从目标节点的邻居节点中采样一部分节点,然后通过聚合操作将邻居节点的特征信息整合到目标节点上。这样一方面减少了计算复杂度,另一方面也保留了图结构中的信息。GraphSAGE模型包含以下几个重要的步骤:

采样:针对每个节点,从其邻居节点中随机采样一定数量的节点作为采样节点集合。采样的目的是探索节点的局部结构,以便更好地捕捉节点的特征。 编码器:对于每个节点,通过一个编码器将其自身的特征向量转换为一个低维的表示向量。编码器可以是一个全连接层、一个卷积神经网络等。 聚合器:通过聚合操作将采样节点集合中的特征向量聚合到目标节点上。常用的聚合方法有均值聚合、最大池化等。聚合的过程可以通过多层的聚合器进行迭代。

通过多层的编码器和聚合器,GraphSAGE能够逐渐聚合更多层次的邻居节点信息,并且逐渐扩大目标节点对邻居节点的感知范围。最终,每个节点都能够获得一个表示其自身和周围结构的嵌入向量,该向量可以用于下游的节点分类、链接预测等任务。

GraphSAGE在图节点嵌入学习任务中具有较好的性能,能够有效地学习图结构中的节点特征。它可以用于社交网络分析、推荐系统、图像分析等领域,对于挖掘和分析图结构数据具有重要的应用价值。

来到这个地方,我们再扩展一下

在大型图中,节点的低维向量embedding被证明了作为各种各样的预测和图分析任务的特征输入是非常有用的。顶点embedding最基本的基本思想是使用降维技术从高维信息中提炼一个顶点的邻居信息,存到低维向量中。这些顶点嵌入之后会作为后续的机器学习系统的输入,解决像顶点分类、聚类、链接预测这样的问题。

GCN虽然能提取图中顶点的embedding,但是存在一些问题:

GCN的基本思想: 把一个节点在图中的高纬度邻接信息降维到一个低维的向量表示。

GCN的优点: 可以捕捉graph的全局信息,从而很好地表示node的特征。

GCN的缺点: 第一,GCN需要将整个图放到内存和显存,这将非常耗内存和显存,处理不了大图;第二,GCN在训练时需要知道整个图的结构信息(包括待预测的节点), 这在现实某些任务中也不能实现(比如用今天训练的图模型预测明天的数据,那么明天的节点是拿不到的)。

为了解决GCN的两个缺点问题,GraphSAGE被提了出来。在介绍GraphSAGE之前,先介绍一下Inductive learning和Transductive learning。注意到图数据和其他类型数据的不同,图数据中的每一个节点可以通过边的关系利用其他节点的信息。这就导致一个问题,GCN输入了整个图,训练节点收集邻居节点信息的时候,用到了测试和验证集的样本,我们把这个称为Transductive learning。然而,我们所处理的大多数的机器学习问题都是Inductive learning,因为我们刻意的将样本集分为训练/验证/测试,并且训练的时候只用训练样本。这样对图来说有个好处,可以处理图中新来的节点,可以利用已知节点的信息为未知节点生成embedding,GraphSAGE就是这么干的。

GraphSAGE是一个Inductive Learning框架,具体实现中,训练时它仅仅保留训练样本到训练样本的边,然后包含Sample和Aggregate两大步骤,Sample是指如何对邻居的个数进行采样,Aggregate是指拿到邻居节点的embedding之后如何汇聚这些embedding以更新自己的embedding信息。下图展示了GraphSAGE学习的一个过程:

图1

简单的逻辑过程如下

第一步,对邻居采样;

第二步,采样后的邻居embedding传到节点上来,并使用一个聚合函数聚合这些邻居信息以更新节点的embedding;

第三步,根据更新后的embedding预测节点的标签。

算法假设模型已经过训练,参数是固定的。

算法是,在每次迭代(或搜索深度),顶点从它们的局部邻居聚合信息,并且随着这个过程的迭代,顶点会从越来越远的地方获得信息。

到这里还有多少没有阵亡的小伙伴?

算法描述了在整个图上生成embedding的过程,其中

首先,(line1)算法首先初始化输入的图中所有节点的特征向量,(line3)对于每个节点 ,拿到它采样后的邻居节点 后,(line4)利用聚合函数聚合邻居节点的信息,(line5)并结合自身embedding通过一个非线性变换更新自身的embedding表示。

注意到算法里面的 ,它是指聚合器的数量,也是指权重矩阵的数量,还是网络的层数,这是因为每一层网络中聚合器和权重矩阵是共享的。网络的层数可以理解为需要最大访问的邻居的跳数(hops),比如在图1中,红色节点的更新拿到了它一、二跳邻居的信息,那么网络层数就是2。为了更新红色节点,首先在第一层(k=1),会将蓝色节点的信息聚合到红色解节点上,将绿色节点的信息聚合到蓝色节点上。在第二层(k=2)红色节点的embedding被再次更新,不过这次用到的是更新后的蓝色节点embedding,这样就保证了红色节点更新后的embedding包括蓝色和绿色节点的信息,也就是两跳信息。

最后来看看示例代码

首先采样

代码语言:javascript
代码运行次数:0
复制
import numpy as np

def sampling(src_nodes, sample_num, neighbor_table):
    """根据源节点采样指定数量的邻居节点,注意使用的是有放回的采样;
    某个节点的邻居节点数量少于采样数量时,采样结果出现重复的节点
    
    Arguments:
        src_nodes {list, ndarray} -- 源节点列表
        sample_num {int} -- 需要采样的节点数
        neighbor_table {dict} -- 节点到其邻居节点的映射表
    
    Returns:
        np.ndarray -- 采样结果构成的列表
    """
    results = []
    for sid in src_nodes:
        # 从节点的邻居中进行有放回地进行采样
        res = np.random.choice(neighbor_table[sid], size=(sample_num, ))
        results.append(res)
    return np.asarray(results).flatten()

def multihop_sampling(src_nodes, sample_nums, neighbor_table):
    """根据源节点进行多阶采样
    
    Arguments:
        src_nodes {list, np.ndarray} -- 源节点id
        sample_nums {list of int} -- 每一阶需要采样的个数
        neighbor_table {dict} -- 节点到其邻居节点的映射
    
    Returns:
        [list of ndarray] -- 每一阶采样的结果
    """
    sampling_result = [src_nodes]
    for k, hopk_num in enumerate(sample_nums):
        hopk_result = sampling(sampling_result[k], hopk_num, neighbor_table)
        sampling_result.append(hopk_result)
    return sampling_result

然后聚合

聚合通过神经网络来实现。在定义的forward函数中,输入neighbor_feature表示需要聚合的邻居节点的特征,它的维度为 Nsrc × Nneighbor × Din​,其中 Nsrc​表示源节点的数量, Nneighbor​表示邻居节点的数量, Din​​表示输入特征的维度。将这些邻居节点的特征通过一个线性变换得到隐藏层特征,进而进行求和、均值、最大值等聚合操作,得到Nsrc × Nneighbor × Din​​维的输出。

代码语言:javascript
代码运行次数:0
复制
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class NeighborAggregator(nn.Module):
    def __init__(self, input_dim, output_dim, 
                 use_bias=False, aggr_method="mean"):
        """聚合节点邻居
        Args:
            input_dim: 输入特征的维度
            output_dim: 输出特征的维度
            use_bias: 是否使用偏置 (default: {False})
            aggr_method: 邻居聚合方式 (default: {mean})
        """
        super(NeighborAggregator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.aggr_method = aggr_method
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_dim))
        self.reset_parameters()
    
    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)
        if self.use_bias:
            init.zeros_(self.bias)

    def forward(self, neighbor_feature):
        if self.aggr_method == "mean":
            aggr_neighbor = neighbor_feature.mean(dim=1)
        elif self.aggr_method == "sum":
            aggr_neighbor = neighbor_feature.sum(dim=1)
        elif self.aggr_method == "max":
            aggr_neighbor = neighbor_feature.max(dim=1)
        else:
            raise ValueError("Unknown aggr type, expected sum, max, or mean, but got {}"
                             .format(self.aggr_method))
        
        neighbor_hidden = torch.matmul(aggr_neighbor, self.weight)
        if self.use_bias:
            neighbor_hidden += self.bias

        return neighbor_hidden

    def extra_repr(self):
        return 'in_features={}, out_features={}, aggr_method={}'.format(
            self.input_dim, self.output_dim, self.aggr_method)

GraphSAGE模型构建(net.py)

代码语言:javascript
代码运行次数:0
复制
class SageGCN(nn.Module):
    def __init__(self, input_dim, hidden_dim,
                 activation=F.relu,
                 aggr_neighbor_method="mean",
                 aggr_hidden_method="sum"):
        """SageGCN层定义
        Args:
            input_dim: 输入特征的维度
            hidden_dim: 隐层特征的维度,
                当aggr_hidden_method=sum, 输出维度为hidden_dim
                当aggr_hidden_method=concat, 输出维度为hidden_dim*2
            activation: 激活函数
            aggr_neighbor_method: 邻居特征聚合方法,["mean", "sum", "max"]
            aggr_hidden_method: 节点特征的更新方法,["sum", "concat"]
        """
        super(SageGCN, self).__init__()
        assert aggr_neighbor_method in ["mean", "sum", "max"]
        assert aggr_hidden_method in ["sum", "concat"]
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.aggr_neighbor_method = aggr_neighbor_method
        self.aggr_hidden_method = aggr_hidden_method
        self.activation = activation
        self.aggregator = NeighborAggregator(input_dim, hidden_dim,
                                             aggr_method=aggr_neighbor_method)
        self.b = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.reset_parameters()
    
    def reset_parameters(self):
        init.kaiming_uniform_(self.b)

    def forward(self, src_node_features, neighbor_node_features):
        neighbor_hidden = self.aggregator(neighbor_node_features)
        self_hidden = torch.matmul(src_node_features, self.b)
        
        if self.aggr_hidden_method == "sum":
            hidden = self_hidden + neighbor_hidden
        elif self.aggr_hidden_method == "concat":
            hidden = torch.cat([self_hidden, neighbor_hidden], dim=1)
        else:
            raise ValueError("Expected sum or concat, got {}"
                             .format(self.aggr_hidden))
        if self.activation:
            return self.activation(hidden)
        else:
            return hidden

    def extra_repr(self):
        output_dim = self.hidden_dim if self.aggr_hidden_method == "sum" else self.hidden_dim * 2
        return 'in_features={}, out_features={}, aggr_hidden_method={}'.format(
            self.input_dim, output_dim, self.aggr_hidden_method)

基于前面定义的采样和节点特征更新方式,就可以实现计算节点嵌入的方法。下面定义了一个两层的模型,隐藏层节点数为64,假设每阶采样点数都为10,那么计算中心节点的输出可以通过以下代码实现。其中前向传播时传入的参数node_features_list是一个列表,其中第0个元素代表源节点的特征,其后的元素表示每阶采样得到的节点特征。

代码语言:javascript
代码运行次数:0
复制
def __init__(self, input_dim, hidden_dim,
                 num_neighbors_list):
        super(GraphSage, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_neighbors_list = num_neighbors_list
        self.num_layers = len(num_neighbors_list)
        self.gcn = nn.ModuleList()
        self.gcn.append(SageGCN(input_dim, hidden_dim[0]))
        for index in range(0, len(hidden_dim) - 2):
            self.gcn.append(SageGCN(hidden_dim[index], hidden_dim[index+1]))
        self.gcn.append(SageGCN(hidden_dim[-2], hidden_dim[-1], activation=None))

    def forward(self, node_features_list):
        hidden = node_features_list
        for l in range(self.num_layers):
            next_hidden = []
            gcn = self.gcn[l]
            for hop in range(self.num_layers - l):
                src_node_features = hidden[hop]
                src_node_num = len(src_node_features)
                neighbor_node_features = hidden[hop + 1] \
                    .view((src_node_num, self.num_neighbors_list[hop], -1))
                h = gcn(src_node_features, neighbor_node_features)
                next_hidden.append(h)
            hidden = next_hidden
        return hidden[0]

    def extra_repr(self):
        return 'in_features={}, num_neighbors_list={}'.format(
            self.input_dim, self.num_neighbors_list
        )

4、数据处理(data.py)

代码语言:javascript
代码运行次数:0
复制
import os
import os.path as osp
import pickle
import itertools
import scipy.sparse as sp
import urllib
from collections import namedtuple
import numpy as np

Data = namedtuple('Data', ['x', 'y', 'adjacency_dict',
                           'train_mask', 'val_mask', 'test_mask'])

class CoraData(object):
    download_url = "https://github.com/kimiyoung/planetoid/raw/master/data"
    filenames = ["ind.cora.{}".format(name) for name in
                 ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]

    def __init__(self, data_root="cora", rebuild=False):
        """Cora数据,包括数据下载,处理,加载等功能
        当数据的缓存文件存在时,将使用缓存文件,否则将下载、进行处理,并缓存到磁盘
        处理之后的数据可以通过属性 .data 获得,它将返回一个数据对象,包括如下几部分:
            * x: 节点的特征,维度为 2708 * 1433,类型为 np.ndarray
            * y: 节点的标签,总共包括7个类别,类型为 np.ndarray
            * adjacency_dict: 邻接信息,,类型为 dict
            * train_mask: 训练集掩码向量,维度为 2708,当节点属于训练集时,相应位置为True,否则False
            * val_mask: 验证集掩码向量,维度为 2708,当节点属于验证集时,相应位置为True,否则False
            * test_mask: 测试集掩码向量,维度为 2708,当节点属于测试集时,相应位置为True,否则False
        Args:
        -------
            data_root: string, optional
                存放数据的目录,原始数据路径: {data_root}/raw
                缓存数据路径: {data_root}/processed_cora.pkl
            rebuild: boolean, optional
                是否需要重新构建数据集,当设为True时,如果存在缓存数据也会重建数据
        """
        self.data_root = data_root
        save_file = osp.join(self.data_root, "processed_cora.pkl")
        if osp.exists(save_file) and not rebuild:
            print("Using Cached file: {}".format(save_file))
            self._data = pickle.load(open(save_file, "rb"))
        else:
            self.maybe_download()
            self._data = self.process_data()
            with open(save_file, "wb") as f:
                pickle.dump(self.data, f)
            print("Cached file: {}".format(save_file))

    @property
    def data(self):
        """返回Data数据对象,包括x, y, adjacency, train_mask, val_mask, test_mask"""
        return self._data

    def process_data(self):
        """
        处理数据,得到节点特征和标签,邻接矩阵,训练集、验证集以及测试集
        引用自:https://github.com/rusty1s/pytorch_geometric
        """
        print("Process data ...")
        _, tx, allx, y, ty, ally, graph, test_index = [self.read_data(
            osp.join(self.data_root, "raw", name)) for name in self.filenames]
        train_index = np.arange(y.shape[0])
        val_index = np.arange(y.shape[0], y.shape[0] + 500)
        sorted_test_index = sorted(test_index)

        x = np.concatenate((allx, tx), axis=0)
        y = np.concatenate((ally, ty), axis=0).argmax(axis=1)

        x[test_index] = x[sorted_test_index]
        y[test_index] = y[sorted_test_index]
        num_nodes = x.shape[0]

        train_mask = np.zeros(num_nodes, dtype=np.bool)
        val_mask = np.zeros(num_nodes, dtype=np.bool)
        test_mask = np.zeros(num_nodes, dtype=np.bool)
        train_mask[train_index] = True
        val_mask[val_index] = True
        test_mask[test_index] = True
        adjacency_dict = graph
        print("Node's feature shape: ", x.shape)
        print("Node's label shape: ", y.shape)
        print("Adjacency's shape: ", len(adjacency_dict))
        print("Number of training nodes: ", train_mask.sum())
        print("Number of validation nodes: ", val_mask.sum())
        print("Number of test nodes: ", test_mask.sum())

        return Data(x=x, y=y, adjacency_dict=adjacency_dict,
                    train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

    def maybe_download(self):
        save_path = os.path.join(self.data_root, "raw")
        for name in self.filenames:
            if not osp.exists(osp.join(save_path, name)):
                self.download_data(
                    "{}/{}".format(self.download_url, name), save_path)

    @staticmethod
    def build_adjacency(adj_dict):
        """根据邻接表创建邻接矩阵"""
        edge_index = []
        num_nodes = len(adj_dict)
        for src, dst in adj_dict.items():
            edge_index.extend([src, v] for v in dst)
            edge_index.extend([v, src] for v in dst)
        # 去除重复的边
        edge_index = list(k for k, _ in itertools.groupby(sorted(edge_index)))
        edge_index = np.asarray(edge_index)
        adjacency = sp.coo_matrix((np.ones(len(edge_index)),
                                   (edge_index[:, 0], edge_index[:, 1])),
                                  shape=(num_nodes, num_nodes), dtype="float32")
        return adjacency

    @staticmethod
    def read_data(path):
        """使用不同的方式读取原始数据以进一步处理"""
        name = osp.basename(path)
        if name == "ind.cora.test.index":
            out = np.genfromtxt(path, dtype="int64")
            return out
        else:
            out = pickle.load(open(path, "rb"), encoding="latin1")
            out = out.toarray() if hasattr(out, "toarray") else out
            return out

    @staticmethod
    def download_data(url, save_path):
        """数据下载工具,当原始数据不存在时将会进行下载"""
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        data = urllib.request.urlopen(url)
        filename = os.path.split(url)[-1]

        with open(os.path.join(save_path, filename), 'wb') as f:
            f.write(data.read())

        return True

主脚本

代码语言:javascript
代码运行次数:0
复制
import torch

import numpy as np
import torch.nn as nn
import torch.optim as optim
from net import GraphSage
from data import CoraData
from sampling import multihop_sampling

from collections import namedtuple

###数据准备
Data = namedtuple('Data', ['x', 'y', 'adjacency_dict',
                           'train_mask', 'val_mask', 'test_mask'])

data = CoraData().data
x = data.x / data.x.sum(1, keepdims=True)  # 归一化数据,使得每一行和为1

train_index = np.where(data.train_mask)[0]
train_label = data.y[train_index]
test_index = np.where(data.test_mask)[0]

####模型初始化
INPUT_DIM = 1433    # 输入维度
# Note: 采样的邻居阶数需要与GCN的层数保持一致
HIDDEN_DIM = [128, 7]   # 隐藏单元节点数
NUM_NEIGHBORS_LIST = [10, 10]   # 每阶采样邻居的节点数
assert len(HIDDEN_DIM) == len(NUM_NEIGHBORS_LIST)
BTACH_SIZE = 16     # 批处理大小
EPOCHS = 20
NUM_BATCH_PER_EPOCH = 20    # 每个epoch循环的批次数
LEARNING_RATE = 0.01    # 学习率
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = GraphSage(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM,
                  num_neighbors_list=NUM_NEIGHBORS_LIST).to(DEVICE)
print(model)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)

####模型训练和测试
def train():
    model.train()
    for e in range(EPOCHS):
        for batch in range(NUM_BATCH_PER_EPOCH):
            batch_src_index = np.random.choice(train_index, size=(BTACH_SIZE,))
            batch_src_label = torch.from_numpy(train_label[batch_src_index]).long().to(DEVICE)
            batch_sampling_result = multihop_sampling(batch_src_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
            batch_sampling_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in batch_sampling_result]
            batch_train_logits = model(batch_sampling_x)
            loss = criterion(batch_train_logits, batch_src_label)
            optimizer.zero_grad()
            loss.backward()  # 反向传播计算参数的梯度
            optimizer.step()  # 使用优化方法进行梯度更新
            print("Epoch {:03d} Batch {:03d} Loss: {:.4f}".format(e, batch, loss.item()))
        test()


def test():
    model.eval()
    with torch.no_grad():
        test_sampling_result = multihop_sampling(test_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
        test_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in test_sampling_result]
        test_logits = model(test_x)
        test_label = torch.from_numpy(data.y[test_index]).long().to(DEVICE)
        predict_y = test_logits.max(1)[1]
        accuarcy = torch.eq(predict_y, test_label).float().mean().item()
        print("Test Accuracy: ", accuarcy)

train()

当然了,方法有很多,我们这里仅是冰山一角。

生活很好,有你更好

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 作者,Evil Genius
  • 开工第一天,岁数又加一。
  • 今天我们分享的内容是,GraphSAGE。
  • 在昨天的文章文献分享---空间转录组学鉴定与肺纤维化远端肺重构相关的分子生态位失调(Xenium + HD)中,我们可以看到文中对空间转录组的分子聚类采用的是
  • 从文章的信息可以获知,1、GraphSAGE是一种图神经网络;2、GraphSAGE,基于转录组数据的空间位置训练了一个图神经网络模型以聚合局部邻域信息并定义一个嵌入空间,该空间为数据集中的所有单个转录组提供了新的表示。然后,应用高斯混合模型在嵌入空间中聚类转录组,并识别生态位,使用共识方法将细胞分配到这些生态位中。
  • 那么我们需要首先知道什么是图神经网络GraphSAGE。
  • GraphSAGE(Graph Sample and Aggregated)是一种用于图节点嵌入学习的图神经网络模型。它通过采样和聚合的方式,将邻居节点的信息聚合到目标节点上,从而学习节点的表示向量。
  • GraphSAGE的核心思想是从目标节点的邻居节点中采样一部分节点,然后通过聚合操作将邻居节点的特征信息整合到目标节点上。这样一方面减少了计算复杂度,另一方面也保留了图结构中的信息。GraphSAGE模型包含以下几个重要的步骤:
  • 通过多层的编码器和聚合器,GraphSAGE能够逐渐聚合更多层次的邻居节点信息,并且逐渐扩大目标节点对邻居节点的感知范围。最终,每个节点都能够获得一个表示其自身和周围结构的嵌入向量,该向量可以用于下游的节点分类、链接预测等任务。
  • GraphSAGE在图节点嵌入学习任务中具有较好的性能,能够有效地学习图结构中的节点特征。它可以用于社交网络分析、推荐系统、图像分析等领域,对于挖掘和分析图结构数据具有重要的应用价值。
  • 来到这个地方,我们再扩展一下
  • 在大型图中,节点的低维向量embedding被证明了作为各种各样的预测和图分析任务的特征输入是非常有用的。顶点embedding最基本的基本思想是使用降维技术从高维信息中提炼一个顶点的邻居信息,存到低维向量中。这些顶点嵌入之后会作为后续的机器学习系统的输入,解决像顶点分类、聚类、链接预测这样的问题。
  • GCN虽然能提取图中顶点的embedding,但是存在一些问题:
  • GCN的基本思想: 把一个节点在图中的高纬度邻接信息降维到一个低维的向量表示。
  • GCN的优点: 可以捕捉graph的全局信息,从而很好地表示node的特征。
  • GCN的缺点: 第一,GCN需要将整个图放到内存和显存,这将非常耗内存和显存,处理不了大图;第二,GCN在训练时需要知道整个图的结构信息(包括待预测的节点), 这在现实某些任务中也不能实现(比如用今天训练的图模型预测明天的数据,那么明天的节点是拿不到的)。
  • 为了解决GCN的两个缺点问题,GraphSAGE被提了出来。在介绍GraphSAGE之前,先介绍一下Inductive learning和Transductive learning。注意到图数据和其他类型数据的不同,图数据中的每一个节点可以通过边的关系利用其他节点的信息。这就导致一个问题,GCN输入了整个图,训练节点收集邻居节点信息的时候,用到了测试和验证集的样本,我们把这个称为Transductive learning。然而,我们所处理的大多数的机器学习问题都是Inductive learning,因为我们刻意的将样本集分为训练/验证/测试,并且训练的时候只用训练样本。这样对图来说有个好处,可以处理图中新来的节点,可以利用已知节点的信息为未知节点生成embedding,GraphSAGE就是这么干的。
  • GraphSAGE是一个Inductive Learning框架,具体实现中,训练时它仅仅保留训练样本到训练样本的边,然后包含Sample和Aggregate两大步骤,Sample是指如何对邻居的个数进行采样,Aggregate是指拿到邻居节点的embedding之后如何汇聚这些embedding以更新自己的embedding信息。下图展示了GraphSAGE学习的一个过程:
  • 简单的逻辑过程如下
  • 第一步,对邻居采样;
  • 第二步,采样后的邻居embedding传到节点上来,并使用一个聚合函数聚合这些邻居信息以更新节点的embedding;
  • 第三步,根据更新后的embedding预测节点的标签。
  • 算法假设模型已经过训练,参数是固定的。
  • 算法是,在每次迭代(或搜索深度),顶点从它们的局部邻居聚合信息,并且随着这个过程的迭代,顶点会从越来越远的地方获得信息。
  • 到这里还有多少没有阵亡的小伙伴?
  • 算法描述了在整个图上生成embedding的过程,其中
  • 首先,(line1)算法首先初始化输入的图中所有节点的特征向量,(line3)对于每个节点 ,拿到它采样后的邻居节点 后,(line4)利用聚合函数聚合邻居节点的信息,(line5)并结合自身embedding通过一个非线性变换更新自身的embedding表示。
  • 注意到算法里面的 ,它是指聚合器的数量,也是指权重矩阵的数量,还是网络的层数,这是因为每一层网络中聚合器和权重矩阵是共享的。网络的层数可以理解为需要最大访问的邻居的跳数(hops),比如在图1中,红色节点的更新拿到了它一、二跳邻居的信息,那么网络层数就是2。为了更新红色节点,首先在第一层(k=1),会将蓝色节点的信息聚合到红色解节点上,将绿色节点的信息聚合到蓝色节点上。在第二层(k=2)红色节点的embedding被再次更新,不过这次用到的是更新后的蓝色节点embedding,这样就保证了红色节点更新后的embedding包括蓝色和绿色节点的信息,也就是两跳信息。
  • 最后来看看示例代码
  • 首先采样
  • 然后聚合
  • GraphSAGE模型构建(net.py)
  • 基于前面定义的采样和节点特征更新方式,就可以实现计算节点嵌入的方法。下面定义了一个两层的模型,隐藏层节点数为64,假设每阶采样点数都为10,那么计算中心节点的输出可以通过以下代码实现。其中前向传播时传入的参数node_features_list是一个列表,其中第0个元素代表源节点的特征,其后的元素表示每阶采样得到的节点特征。
  • 4、数据处理(data.py)
  • 主脚本
  • 当然了,方法有很多,我们这里仅是冰山一角。
  • 生活很好,有你更好
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档