GNN

图神经网络简介

GNN

Posted by Roger on April 9, 2022

本文翻译总结自博文,感兴趣的建议自己看原文

介绍

  真实世界中很多物体是由它们与其它事物的联系来给出定义的,一组对象以及它们相互之间的联系可以被很自然地用图来表征。作用于图数据上的神经网络叫做图神经网络(GNN,Graph Neural Network),随着GNN表达能力的不断提高,我们已经开始看到它在实际生活中发挥作用:如抗生素发现、物理仿真、虚假新闻检测、流量预测及推荐系统等。

  描述一张图需要用到三种不同层面的属性:

  1. 节点(vertex / node)属性:如节点标识(identity),邻居数量等
  2. 边(edge / link)属性:如边标识,边权重等
  3. 全局(global / master node)属性:如节点数目,最长路径等

  根据边的方向性,图可以分为两类:

  1. 无(双)向边(Undirected edge)
  2. 有向边(Directed edge)

    图数据

      除了一些社交网络等直观的可以表达为图的数据,很多别的数据也可以表示为图,如图片和文本。尽管有悖直觉,但通过将图像和文本视为图,可以了解更多关于它们的对称性和结构的信息,并建立一种直觉,有助于理解其他不太像网格的图数据。

    图片作为图数据

      一般我们认为图片是具有多个通道的矩形栅格,以数组(array)的形式来表示。另一种方法是将图片看做有着规则结构的图,其中每个像素代表一个节点,不同节点之间通过边与相邻像素连接。每个非边界上的像素点都有8个邻居,因此存储在每个节点的信息是一个三维向量,分别对应RGB三个通道值。
      一种可视化图连接性的方式是邻接矩阵(adjacency matrix)。以一张5*5的单通道图片为例,总共有$5\times5=25$个像素,可以构造一个$25\times25$的邻接矩阵并填充两个节点共享一条边的单元。

    文本作为图数据

      通过给每个字符/单词/token赋予一个数值索引,从而将文本用一串索引表示。这一方法创建了一个简单的有向图,其中每个字符/索引是一个节点(node)并通过一条边连接到后一个节点。

【注】实际中,并不会真的采用上述两种方式编码图片和文本。因为图片和文本都具有非常规律的结构,这样表示会带来大量的冗余:图片在其邻接矩阵中有一个带状结构,因为所有节点(像素)都连接在一个网格中。文本的邻接矩阵只是一条对角线,因为每个单词只与前一个单词连接,并连接到一个单词。

异构结构作为图数据

现实中很多数据是异构结构(heterogeneously structured)的。此时,每个节点的邻居数量可能都是不同的,这样的数据很难以图以外的方式表征。现实中这样的例子有:

  • 分子作为图
    分子是由原子和电子构成的,所有的粒子是相互作用的,但当一对原子以稳定的距离存在时,我们说它们共享一个共价键。不同的源自对和共价键有不同的距离(如单键和双键)。这种3D结构很容易抽象为一个图,其中节点是原子而边则是共价键。
  • 社交网络作为图 社交网络是研究人们、机构和组织的集体行为模式的工具。我们可以构建由个体作为节点,相互之间的关系作为边的图。
  • 论文引用网络作为图
    可以将论文看做节点,每个有向边代表文章之间的引用。此外还可以向每个节点添加进入节点的每篇文章的信息,比如摘要的embedding。
  • 其它
    在计算机视觉中,我们有时想要标记视觉场景中的对象。然后,我们可以通过将这些对象视为节点并将它们的关系视为边来构建图形。机器学习模型,编程代码以及数学方程也可以被解释为图,图中变量作为节点,边代表以这些变量作为输入和输出的运算。

    【注】这里感觉原文表述有问题,深度学习中模型图是以操作作为节点,变量沿边在节点之间流动。

图数据相关的任务

  图上的预测任务一般分为三种类型:图层面、节点层面和边层面。   在图层面任务,我们为整张图预测一些性质;对于节点层面的任务,我们为图中的每个节点预测一些性质;对于边层面任务,我们预测边的形式(比如是否存在)。

图层面任务

  在图层面任务,我们预测整张图的性质。例如,对于一个分子结构图,我们可能想去预测这个分子的气味,或者它是否会与与疾病有关的受体结合。

节点层面任务

  节点级任务与预测图中每个节点的角色有关。节点级预测问题的一个典型例子是 Zach 的空手道俱乐部,预测问题是在争执之后对给定成员是否忠于 Mr. Hi 或 John H 进行分类。在这种情况下,节点与教练或管理者之间的距离与此标签高度相关。
  节点层面的预测问题类比于图片分割问题,在图片分割中我们要给图片中每个像素的角色打标签。文本问题中一个相似的任务是预测句子中每个单词的词性(例如名词、动词、副词等)。

边层面任务

  边层面推理的一个例子是图片场景理解。除了识别图片中的物体,深度学习模型还可以被用于预测它们之间的关系。我们可以将这一任务看做一个边层面的分类任务:给定代表图片中物体的节点,我们希望预测哪些节点共享一条边或者这条边的值是多少。如果我们想发现实体之间的联系,我们可以将图看做全部相互连接的,然后基于预测的结果来对图进行修剪以获得一个稀疏的图。

如何用神经网络解决图任务

  首要问题是思考如何让图的表示与神经网络结构兼容。机器学习模型通常采用规整数组作为输入。图具有多达四种类型的信息可以用于预测:节点、边、全局上下文和连通性(connectivity)。前三种是相对直观的:比如,使用节点,我们可以通过为每个节点分配一个索引$i$并来构建一个节点特征矩阵$N$并将节点$node_i$的特征存储到该矩阵中。虽然这些矩阵具有可变数量的例子,但它们不需任何特殊技术就能处理。
  相比之下,表示图的连通性要更加复杂一些。最明显的方法是使用邻接矩阵,因为它更容易张量化。然而这一方法有一些缺点:当节点的数量特别多时,矩阵会变得很大;每个节点的连接数变化很大,通常会产生很稀疏的矩阵,浪费存储空间。
  另一个问题是,有许多邻接矩阵可以编码相同的连通性(比如不同的分配node index的方式),并且不能保证这些不同的矩阵会在深度神经网络中产生相同的结果(也就是说,它们不是排列不变(permutation invariant)的)。
  一个更内存友好的表达稀疏矩阵连的方法是邻接列表。邻接列表的第k项是一个元组$(i,j)$,代表了节点$n_i$和$n_j$之间的边$e_k$。因为我们知道边的数量要远小于邻接矩阵的规模($n_{node}^2$),这样做避免了图中不连通部分的计算和存储。

图神经网络(GNN)

  GNN是对图的所有属性(节点、边、全局上下文)的可优化转换,它保留了图的对称性(排列不变性,permutation invariance)。GNN采用的“图进,图出”的架构意味着这些模型采用一张加载了信息到节点、边和全局上下文的图作为输入,然后在不改变输入图的连通性的条件下渐进地对这些embedding做变换。

最简单的GNN

  从最简单的GNN开始,该GNN可以为图的所有属性(节点、边、全局上下文)学习到新的embedding,但暂时不使用图的连通性。
  该GNN为图中的每个组件使用一个单独的MLP(或其它的可微模型),我们将其称之为GNN的一个层。对于每个节点向量,我们让其经过一个MLP以得到一个学习后的节点向量;对于图的每条边也学习一个embedding;对于全局上下文,为整张图学习一个embedding。与常见的神经网络模块或层相同,我们可以堆叠许多GNN层形成一个更大的模型。
  因为该GNN不更新输入图的连通性,我们可以将该GNN的输出用与输入相同的邻接列表表示,并且输出的图和输入图有着相同数量的特征向量。不同的是,因为GNN对于每个节点、边、全局上下文都做过更新,输出的图的embedding与输入图的不同。

通过汇集信息进行GNN预测

  在有了简单的GNN后,如何根据这些信息做预测呢?以二分类为例,如果任务是基于节点来做二分预测并且图已经包含了节点信息,那么一个直接的想法就是对每个节点embedding添加加一个线性分类器。但如果只有边的信息却仍要做节点层面的预测,则需要一个能根据边为节点收集信息的方法来做预测。这一过程可以通过汇集(pooling)来实现,池化过程包含两步:

  1. 对于要汇集的每一项,收集它们的embedding并拼接为一个矩阵
  2. 聚合收集到的embedding,通常是通过求和操作

  如果只有节点层面的特征并且要做边层面的预测,可以使用pooling来传递/路由信息(如相邻边的信息)到需要的节点。类似的,可以汇集边层面的信息来做节点层面的预测。如果我们只有节点/边层面的特征并且需要做全局信息预测,则需要汇集所有可达的节点/边信息并做聚合。这一操作类似于CNN中的Global Average Pooling。Pooling操作是可用于构建更加复杂GNN模型的模块,如果我们有新的图属性,只需要定义如何从一个属性到另一属性传递信息。

在图的不同组件之间传递信息

  为了让我们学到的embedding包含图连通性的信息,我们可以在GNN层中通过使用pooling来做更加复杂的预测。Pooling可以通过消息传递来实现(message passing),相邻的节点/边可以交换信息并影响其它节点/边的embedding。消息传递的步骤有三步:

  1. 对于图中每个节点,汇集所有邻居节点的embedding/message
  2. 通过聚合函数(如sum)来汇集所有的消息
  3. 所有汇集的信息被传递到一个更新函数,通常是一个神经网络

  Message passing可以被应用到节点或边。本质上,message passing和卷积操作都是聚合并处理相邻单元的信息来更新当前单元的值,区别在于图的单元是一个节点而图片的单元是一个像素。二者的不同之处在于节点的邻居数量是不固定的,而图片中的像素都有(至少对图片中间的像素)固定数量的相邻单元。

学习边的表征

  我们可以用与之前使用相邻节点信息相同的方式合并来自相邻边的信息,首先将边信息合并,使用更新函数对其进行编号,然后存储。然而,存储在图中的节点和边的信息不一定是相同的size/shape。结合二者信息的一种方式是学习一个从边/节点空间到节点/边空间的线性映射,另一种方法是在通过更新函数之前将它们拼接(concatenate)起来。   更新图的哪个属性以及以何种顺序更新它们是设计GNN时需要考虑的一项内容。我们可以选择是否要在更新边embedding之前更新节点embedding,或其它的方法。还可以采取交替更新边和节点的方式,此时我们有四个更新后的表示,它们组合成新的节点和边表示:节点到节点(线性)、边到边(线性)、节点到边(边层)、边到节点(节点层)。

添加全局表征

  目前为止描述过的网络还有一个缺点:即使应用多次message passing,相互之间离得很远的节点之间仍无法有效地传递信息。对于每个节点,如果我们有k层,那么信息最多只会传播k步远。这对于依赖于远距离节点/节点组之间信息的预测是不利的,一个解决方法是让节点都可以接受来自所有的其它节点的信息,但这一方法对于大的图来说是不现实的,因为计算量太大了。
  另一种解决方法是使用图的全局表征,该表征有时被叫做主节点(master node)或上下文向量(context vector)。这一全局上下文向量连接到网络中所有的节点和边,作为它们之间相互传递信息的桥梁,构建了一个将整张图作为整体的表征。
  那么如何利用这些信息呢?以一个节点为例,我们可以考虑来自其邻居节点,相连的边以及全局信息。为了利用所有这些信息源来调整新的节点embedding,简单的方法是拼接。此外,我们还可以将这些特征通过一个线性映射映射到同一个空间然后相加或者添加一个逐feature的调制层(这可以被看做是一种逐feature的注意力机制)。

设计图神经网络

  可以从多个层面来自定义不同的GNN模型:

  1. GNN层数,也叫做深度
  2. 每个属性更新后的维度。更新函数是带有ReLU激活函数以及一个LayerNorm层用来标准化激活函数输出的一层MLP
  3. 用于Pooling的聚合函数:max,mean或sum
  4. 要更新的图属性或message passing的方式:节点,边以及全局表征。可以通过boolean开关来控制这些内容

  GNN性能的下界随着层的增加有时反而降低,这可能是因为:具有较多层数的GNN将以更远的距离传播信息,并可能在许多连续迭代中使其节点表示被“稀释”。一般来说,图属性之间的消息传递越多,模型的性能会越好。

图神经网络的种类

  1. multi-edge graph(或称multigraph),其中一对节点可以共享多种类型的边。这一类型网络用于我们想基于节点类型建模它们之间交互的场景。例如对于社交网络,我们可以基于关系类型(熟人,朋友,家人)来指定边的类型。
  2. nested graph(或hierarchical graph),这种网络对于表示层级信息很有用。例如,我们可以考虑一个分子网络,其中一个节点代表一个分子,如果我们有一种(反应)方式将其中一个分子转化成另外一种时,两个节点之间共享一条边。此时,我们可以通过一个GNN学习分子层面的表征,另一个网络学习网络交互层面的表征,然后在训练过程中交替地学习两个GNN来在嵌套图上学习。
  3. hypergraph,这种图中一条边可以连接到多个节点而不仅仅只是两个。对于给定的一个图,我们可以通过识别节点社群并分配一条与社群内所有节点相连的hyper-edge(汇聚点叫做hyper-node)。超图的边hyper-edge是任意非空顶点集。一个k-超图的超边,它们恰好连接了k个顶点;因此,正常图是 2 超图(因为一条边连接 2 个顶点)。

GNN中的Sampling和Batching

  常规的训练神经网络的方法是用在训练数据中抽取的mini-batch上计算得到的梯度来更新网络参数。这种方法在图数据上遇到了挑战,因为节点和相互连接的边的数量是可变的,意味着我们无法有固定的batch size。
  图数据做batching的主要思想是创建保留了更大图的必要特性的子图。图的采样高度依赖于上下文并涉及到图中部分节点与边的选择。这些操作在某些情形(如文章引用网络)下是可行的,但在其它情况下就不合理(如分子图中,其子图代表的是一个不一样的,更小的分子)。如果我们关注在邻居层面保留结构,一种方法是随机采样均匀数量的节点,即我们的节点集。然后将与节点集距离为k的邻居节点(及相应的边)添加到该子图中。每个邻居可以被认为是一个独立的图且可以将这些子图拼成batch来训练,训练的损失函数可以通过mask来限制只考虑节点集内的内容,因为所有的相邻节点的邻居信息都是不完整的。一种更高效的策略是首先随机采样一个节点,将其邻域扩展到距离k,然后在扩展的集合中选择另一个节点。一旦构造了一定数量的节点、边或子图,这些操作就可以终止。如果情形允许,我们可以通过选择初始节点集然后对恒定数量的节点进行二次采样(采用随机节点采样,随机游走采样,diffusion采样或 Metropolis 算法)来构建恒定数量的邻居。当图太大而无法全部放入内存时,图采样尤其重要。由此也引申出了新的架构和训练策略比如 Cluster-GCN 和 GraphSaint。

归纳偏差(Inductive Biases)

  为了获得更好的性能,我们应该根据数据或任务的特点来选择特定的模型结构,正如对于图片使用 translation invariant 的卷积;对于文本这种顺序关键的数据采用 RNN 来序贯处理;当需要对数据的某部分特别关注时使用注意力机制,如BERT或GPT-3。这些根据数据/任务特点来做专门处理叫做 inductive biases。对于图数据也是如此。具体来说,这意味着在集合上设计转换时:节点或边上的操作顺序应该无关紧要,并且操作应该在可变数量的输入上工作。

不同聚合方式的比较

  在GNN中,因为每个节点的邻居数量是不固定的,聚合信息的操作需要是可微的而且这一操作需要不受节点顺序和节点数量的影响,所以决定如何从相邻节点或边汇聚信息(pooling information)是很关键的一步。
  聚合函数需要对于相似的输入有着相似的聚合结果,常见的 permutation-invariant 的聚合函数有 sum,mean,max,此外还有统计数据(如方差)也可以采用。所有这些运算可有不固定数量的输入并提供相同的输出,而不受输入顺序的影响。
  没有一种 pooling 方式可以始终区分图对,即可能会对两个不同的图给出相同的输出结果。Mean 运算一般用于节点有着数量差别很大的邻居数量或我们需要获取局部邻居的 normalized feature 时。Max运算用于当我们想强调局部邻居某个“突出”的特征时。Sum运算则通过提供局部所有特征的快照,在Mean和Max之间提供了一种平衡,但因为它是未标准化的,所以也可以用来强调异常值。实际应用中,Sum运算用的比较多。
  新的聚合方法如 Principle Neighborhood Aggregation 通过拼接并添加一个缩放函数来考虑多个聚合操作,该缩放函数取决于要聚合的实体的连接程度。同时,也可以设计用于特定领域的聚合函数,比如“四面体手性”(Tetrahedral Chirality)聚合算子。

GCN作为子图函数近似

  一种看待1-degree 邻居查找的k层GCN(或MPNN)的方法是作为一个神经网络,它对大小为 k 的子图的学习嵌入进行操作。当关注一个节点时,经过k层后,更新的节点表征具有最多k步远的有限邻居视野,本质上是一个子图表征。对于边的表征也是同样的情况。所以一个GCN是本质上是收集所有可能的规模为k的子图然后从一个节点或边的有利位置学习向量表征。可能的子图的数量可以组合增长,因此相比于在 GCN 中那样动态地构建它们,从一开始就枚举这些子图可能不太现实。

边及图的对偶

  值得一提的是边预测及节点预测尽管看起来不同,但都可以约简为同一个问题,即:在图 G 上对边的预测任务可以转换为在图 G 的对偶上的节点层面的预测任务。
  为了获取图 G 的对偶,我们可以将节点转换为边(或反过来)。一个图与它的对偶包含着相同的信息,只是表达方式不同。有时根据对偶性质,将问题以另一种形式表示会使得问题更加容易求解,比如在傅里叶空间处理频率。简而言之,为了解决在图 G 上的边分类问题,我们可以考虑在图 G 的对偶上做图卷积(类似于在图 G 上学习边的表征),这一想法是由 Dual-Primal Graph Convolutional Networks 提出的。

矩阵乘法与图上游走的关系

  邻接矩阵和$A_{n_{nodes}\times n_{nodes}}$和与节点的特征矩阵$X_{n_{nodes}\times node_{dim}}$的矩阵乘实现了一个以Sum为聚合函数的message passing。令矩阵$B=AX$,则对于$B$中的任意元素$B_{i,j}$可以表示为$\langle A_{row_i}X_{column_j} \rangle=\sum_{k=1}^nA_{i,k}X_{k,j}$。因为$A_{i,k}$是只有当$node_i$和$node_j$之间有边存在时才存在的二进制项,所以内积本质上是汇聚所有与$node_i$共享边的节点的特征。应该注意的是这个消息传递并不更新节点特征而只是汇聚邻居节点的特征,但这个需求可以通过在矩阵乘之前或之后令$X$通过一个可微的变换(如 MLP)来实现。
  从这点出发,使用邻接列表可以使得我们不需要对$A_{i,j}$中为0的值做求和操作而只关注有效值。此外,这一无需矩阵乘的方法也使得我们可以使用除Sum之外其它的聚合函数。
  我们可以想象通过多次上述操作可以让我们将信息传递到更远的距离。从这个意义上说,矩阵乘是遍历图的一种形式。当我们看邻接矩阵的幂$A^K$时,这种关系也很明显。通过邻接矩阵的幂乘,我们考虑了所有的中间节点。

图注意力网络

  另一种在图各属性之间交流信息的方式是通过注意力。当我们考虑一个节点与它的 1-degree 邻居节点的 Sum 聚合时,也可以考虑使用一个加权和。这一做法带来的挑战是如何以排列不变(permutation invariant)的方式来关联权重。一种方法是使用一个标量打分函数来基于节点对分配权重:$f(node_i, node_j)$。此时这个打分函数可以被理解为衡量邻居节点与中心节点有多相关的函数。权重可以被标准化,例如使用 softmax 函数将大部分权重集中在与任务最相关的邻居上,这一过程是 permutation invariant 的,因为打分是在节点对上进行的。这一概念便是 Graph Attention Network (GAT)以及 Set Transformer 的基础。一个常见的打分函数是内积并且节点通常在通过打分函数前会经过线性映射变换为 query 和 key 来增加打分机制的表达能力。打分得到的权值也可以作为衡量边对于任务重要性的可解释指标。
  Transformer模型可以看做是以token作为节点,并假设所有节点之间具有全连接。而GNN则是假设节点之间的连接是稀疏的。

图的解释与归因

  我们要解释的图概念因上下文的不同而不同。因为图概念的多样性,构建图的解释也有多种方式。GNNExplainer 将这个问题当做提取对任务最重要的子图。而归因技术则为与任务相关的子图分配排好序的重要性值。

生成式建模

  除了在图上学习一个预测模型,我们也关心如何学习一个生成式的模型。通过生成式模型,我们可以通过从一个已经学到的分布中采样或通过从一个起始点补全一张图的方法来生成新的图。一个相关的应用是设计新的药物,其中具有可以某些特定性质的新的分子图被作为治疗某种疾病的候选。
  图生成的一个关键挑战是如何对图的拓扑建模,这一部分有$N_{nodes}^2$项且大小差别可以很大。一个解决方法是用 autoencoer 框架像处理图片一样直接对邻接矩阵建模,把边的出现与否问题当做一个二分类任务来处理。通过只预测已知的边和不存在的边的部分子集,可以避免做多达$N_{nodes}^2$项的预测。graphVAE 学习对邻接矩阵中的连接模式与某些非连接模式进行建模。
  另一种方法是顺序地构建一张图:从一张图开始,不断地对其施加离散的动作如增加或删除节点/边。为了避免对离散的动作估计梯度,可以使用 policy gradient。通过一个自回归模型(如 RNN)或用强化学习中可以实现这一目的。