论文解读(GraphSAGE)《Inductive Representation Learning on Large Graphs》
《Inductive Representation Learning on Large Graphs》
论文标题:Inductive Representation Learning on Large Graphs
论文作者: William L. Hamilton (wleif@stanford.edu), Rex Ying (rexying@stanford.edu)
论文来源:NIPS 2017
论文链接:chrome-extension://ibllepbpahcoppkjjllbabhnigcbffpi/https://arxiv.org/pdf/1706.02216.pdf
论文代码:https://github.com/williamleif/GraphSAGE
2 介绍及相关工作
Transductive Learning
假设要测试的节点和训练的节点在一个图中,并且训练过程中图结构中的所有节点都被考虑进去。它们只能得到已经包含在训练过程中的节点嵌入,对于训练过程中没有出现过的未知节点则束手无策。由于它们在一个固定的图上直接生成最终的节点嵌入,如果这个图的结构稍后有所改变,就需要重新训练。
直推式学习已经预先观察了所有数据,含训练和测试数据集。 从已经观察到的数据集中学习,然后预测测试数据集的标签。 即过程会利用这些不知道数据标签的测试集数据的模式和其他信息。还有一个区别是,一旦有新的节点出现,直推式学习需要重新训练模型。
Inductive Learning
主要观点是:节点的嵌入可以通过一个共同的聚合邻居节点信息的函数得到,在训练时只要得到这个聚合函数,就可以将其泛化到未知的节点上。
Factorization-based embedding approaches- 一些使用随机游走统计和基于矩阵分解的学习目标的节点嵌入方法。
- 这些嵌入算法中的大多数直接为单个节点训练节点嵌入。 因此需要昂贵的额外训练(例如,通过随机梯度下降)来对新节点进行预测。
Supervised learning over graphs
- 基于 Graph kernel 的方法,其中图的特征向量来自不同的图内核。
- 如果半监督时带有 label 的节点过少,GCN 的性能会有比较严重的下降;
- 浅层的 GCN 网络不能大范围地传播 label 信息 (层级越深,节点的感受野越大);
- 深层的 GCN 网络会导致过度平滑 (smooth) 的问题;
本文提出的 GraphSAGE(Inductive Method) 可以利用所有图中存在的结构特征(如:节点度,邻居信息),去推测 Unseen Node 的节点 Embeeding。
- 先对邻居随机采样,降低计算复杂度(Figure 1 :一跳邻居采样数=3,二跳邻居采样数=5)
- 生成目标节点 Emebedding:先聚合2跳邻居特征,生成一跳邻居 Embedding,再聚合一跳邻居 Embedding,生成目标节点 Embedding,从而获得二跳邻居信息。
- 将 Embedding 作为全连接层的输入,预测目标节点的标签。
3 GraphSAGE Method
GraphSAGE 的核心思想:不是试图学习一个图上所有 Node Embedding,而是学习一个为每个 Node 产生 Embedding 的映射。
3.1 Embedding generation algorithm
该部分假设模型已经被训练过了,并且参数是固定的。
我们假设我们已经学习了 $K$ 个聚合器函数的参数,
$\text { AGGREGATE } \left._{k}, \forall k \in\{1, \ldots, K\}\right)$
用模型的不同层或“搜索深度”之间传播信息。
步骤:
GraphSAGE 的前向传播算法如下,前向传播描述了如何使用聚合函数对节点的邻居信息进行聚合,从而生成节点 Embedding:
- $ \mathcal{G}=(\mathcal{V}, \mathcal{E})$ 表示一个图;
- $ K$ 是网络的层数,也代表着每个顶点能够聚合的邻接点的跳数,因为每增加一层,可以聚合更远的一层邻居的信息;
- $ x_{v}$,$\forall v \in \mathcal{V}$ 表示节点 $v$ 的特征向量,并且作为输入;
- $ \left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}$ 表示在 $k-1$ 层中节点 $v$ 的邻居节点 $u$ 的 Embedding;
- $ \mathbf{h}_{\mathcal{N}(v)}^{k}$ 表示在第 $k$ 层,节点 $v$ 的所有邻居节点的特征表示;
- $ \mathbf{h}_{v}^{k}$,$\forall v \in \mathcal{V}$ 表示在第 $k$ 层,节点 $v$ 的特征表示;
- $ \mathcal{N}(v) $ 定义为从集合 $\{u \in v:(u, \mathcal{V}) \in \mathcal{E}\}$ 中的固定 $size$ 的均匀取出,即 GraphSAGE 中每一层的节点邻居都是是从上一层网络采样的,并不是所有邻居参与,并且采样后的邻居 $size$ 是固定的;
3.2 Learning the parameters of GraphSAGE
损失函数分为基于图的无监督损失和有监督损失。
- 基于图的无监督损失:目标是使节点 $u$ 与 “邻居” $v$ 的 Embedding 相似,与无边相连的节点 $v_n$ 不相似。
$J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right)$
- $z_{u}$ 为节点通过 GraphSAGE 生成的 Embedding ;
- 节点 $v$ 是节点 $u$ 结果固定长度的 Random walk 到达的"邻居";
- $v_{n} \sim P_{n}(u)$ 表示负采样:节点 $v_{n}$ 是从节点 $u$ 的负采样分布 $P_{n}$ 采样的, $Q$ 为采样样本数;
- Embedding 之间的相似度通过向量点积计算得到;
- 基于图的有监督损失:无监督损失函数的设定来学习节点 Embedding 可以供下游多个任务使用,若仅使用在特定某个任务上,则可以替代上述损失函数符合特定任务目标,如交叉熵。
3.3 Aggregator Architectures
算法可以应用于任意顺序的节点表示向量(即:排列不变性),所以聚集函数(aggregation function)应该是对称的。
排列不变性(permutation invariance):指输入的顺序改变不会影响输出的值。
这里采用 Mean aggregator 、LSTM aggregator 、Pooling aggregator。
- Mean aggregator
Mean aggregator 将目标顶点和邻居顶点的第 $k?1$ 层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第 $k$ 层表示向量。
GCN 的 inductive 变形:
$h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{mean}\left(\left\{h_{v}^{k-1}\right\} \cup\left\{h_{u}^{k-1}, \forall u \in N(v)\right\}\right)\right.$
Convolutional aggregator $\begin{array}{c}h_{N(v)}^{k}=\operatorname{mean}\left(\left\{h_{u}^{k-1}, u \in N(v)\right\}\right) \\h_{v}^{k}=\sigma\left(W^{k} \cdot C O N C A T\left(h_{v}^{k-1}, h_{N(u)}^{k}\right)\right)\end{array}$- LSTM聚合:LSTM函数不符合 "排列不变性" 的性质,需要先对邻居随机排序,然后将随机的邻居序列 Embedding $ \left\{x_{t}, t \in N(v)\right\}$ 作为 LSTM 输入。
-
Pooling 聚合:
它既是对称的,又是可训练的。Pooling aggregator 先对目标顶点的邻居顶点的 Embedding 向量进行一次非线性变换,之后进行一次 Pooling 操作(max pooling or mean pooling),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第 $k$ 层表示向量。
一个element-wise max pooling操作应用在邻居集合上来聚合信息:
$\text { AGGREGATE }_{k}^{\mathrm{pool}}=\max \left(\left\{\sigma\left(\mathbf{W}_{\text {pool }} \mathbf{h}_{u_{i}}^{k}+\mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}\right)$
$\mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right)$
其中
- $max$ 表示 $element-wise$ 最大值操作, 取每个特征的最大值
- $\sigma$ 是非线性激活函数
- 所有相邻节点的向量共享权重, 先经过一个非线性全连接层, 然后做 $max-pooling$
- 按维度应用 $max / mean \quad pooling$,可以捕获邻居集上在某一个维度的突出的综合的表现。
4 Experiments
在三个基准任务上测试了GraphSAGE的性能。
datasets
-
- Citation 论文引用网络(节点分类)
- Reddit 帖子论坛 (节点分类)
- PPI 蛋白质网络 (graph分类)
four baselines
-
- Random classifer,随机分类器
- Raw features,手工特征(非图特征)
- Deepwalk(图拓扑特征)
- DeepWalk + features, deepwalk+手工特征
基于图的无监督损失
$J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right)$
基于图的有监督损失
交叉熵
实验设置
- $K=2$,聚合两跳内邻居特征
- $S_1=25,S_2=10$: 对一跳邻居抽样25个,二跳邻居抽样10个
- RELU 激活单元
- Adam 优化器(仅 DeepWalk 使用 SGD )
- 文中所有的模型都是用 TensorFlow 实现
- 对每个节点进行步长为 5 的 50 次随机游走
- 负采样参考 Word2vec,按平滑 degree 进行,对每个节点采样 20 个
- 保证公平性:所有版本都采用相同的minibatch迭代器、损失函数、邻居采样器
- 实验测试了根据式1的损失函数训练的GraphSAGE的各种变体,还有在分类交叉熵损失上训练的可监督变体
- 对于Reddit和citation数据集,使用”online”的方式来训练DeepWalk
- 在多图情况下,不能使用DeepWalk,因为通过DeepWalk在不同不相交的图上运行后生成的embedding空间对它们彼此说可能是arbitrarily rotated的。
实验结果1:分类准确率
结论:
-
- GraphSAGE的性能显著优于baseline方法。
- 三个数据集显示:一般是 LSTM 或 pooling 效果比较好,有监督都比无监督好。
- LSTM 是为有序数据而不是无序集设计的,但是基于 LSTM 的聚合器显示了强大的性能。
- 可以看到无监督 GraphSAGE 的性能与完全监督的版本相比具有相当的竞争力,这表明文中的框架可以在不进行特定于任务的微调( task-specific fine-tuning )的情况下实现强大的性能。
实验结果2:Timing experiments on Reddit data
- 计算时间:下图A中GraphSAGE中LSTM训练速度最慢,但相比DeepWalk,GraphSAGE在预测时间减少100-500倍(因为对于未知节点,DeepWalk要重新进行随机游走以及通过SGD学习embedding)
- 邻居抽样数量:上图B中邻居抽样数量递增,边际收益递减(F1),但计算时间也变大。 平衡F1和计算时间,将S1设为25。
- 聚合 K 跳内信息:在 GraphSAGE, K=2 相比 K=1 有10-15%的提升;但将K设置超过2,边际效果上只有 0-5% 的提升,但是计算时间却变大了10-100倍。
『总结不易,加个关注呗!』