torch_geometric官方教程
该笔记主要参考:Creating Message Passing Networks — pytorch_geometric documentationMessagePassing功能是图神经网络区别于其他网络的重要部分,其给图神经网络提供了一种信息传递的功能,从某种程度上讲,MessagePassing就是一种特殊的卷积操作,让不同节点的数据相互交流。
将卷积操作泛化到不规则的域被称为聚合或者信息传递。对于$\mathbf{x}_{i}^{(k-1)}\in \mathbb{R}^F$,其代表k-1层的节点 i .$\mathbf{e}_{j,i}\in \mathbb{R}^D$代表从节点 j 到节点 i 的边。信息传递可以表示为如下:
$$
\mathbf{x}_{i}^{(k)}=\gamma^{(k)}\left(\mathbf{x}_{i}^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \phi^{(k)}\left(\mathbf{x}_{i}^{(k-1)}, \mathbf{x}_{j}^{(k-1)}, \mathbf{e}_{j, i}\right)\right)
$$
其中$\bigoplus$代表一个可微分的、与顺序无关的函数,例如$\text{max}, \text{mean}, \text{sum}$这些函数,而$\gamma$与$\phi$则代表可微分的函数,例如MLP.

MessagePassing基类

PyG提供了 MessagePassing 基类,其帮助我们创建具有信息传递的图神经网络。该基类可以自动帮助我们进行信息传播,用户只需要定义函数$\phi$(message())和$\gamma$(update())即可。同时,该基类还有聚合框架可以直接使用,例如 aggr = “add”

应用GCN层

GCN层计算如下:
$$
\mathbf{x}_{i}^{(k)}=\sum_{j \in \mathcal{N}(i) \cup{i}} \frac{1}{\sqrt{\operatorname{deg}(i)} \cdot \sqrt{\operatorname{deg}(j)}} \cdot\left(\mathbf{W}^{\top} \cdot \mathbf{x}_{j}^{(k-1)}\right)+\mathbf{b}
$$
上式子中,邻近的节点首先通过一个权重矩阵$\mathbf{W}$转换,然后再被其自身的度数进行归一化,最终加和并加入bias项。上面的式子可以被分为以下六步骤:

  1. 为邻接矩阵加入自环
  2. 线性转换节点特征矩阵
  3. 计算归一化的系数
  4. 归一化节点特征
  5. 聚合
  6. 加入bias
    Steps 1-3 在信息传递前进行,Steps 4-5 可以使用 MessagePassing 基类来进行简单的计算。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = Linear(in_channels, out_channels, bias=False)
self.bias = Parameter(torch.empty(out_channels))

self.reset_parameters()

def reset_parameters(self):
self.lin.reset_parameters()
self.bias.data.zero_()

def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]

# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

# Step 2: Linearly transform node feature matrix.
x = self.lin(x)

# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

# Step 4-5: Start propagating messages.
out = self.propagate(edge_index, x=x, norm=norm)

# Step 6: Apply a final bias vector.
out = out + self.bias

return out

def message(self, x_j, norm):
# x_j has shape [E, out_channels]

# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j

GCNConvMessagePassing 基类继承,并使用add传播。所有的层逻辑均发生于forward层。这里的norm形状为$[\text{num_edges}, ]$。然后我们使用 propagate() 方法,这个方法将会调用 message() aggregate()update() 。我们将节点嵌入 x 与 归一化系数 norm 传入。
message() 函数中,我们需要使用 norm 归一化 x_j 的邻居节点,这里 x_j 是一个 lifted tensor,即按边展开后的张量,它包含每条边对应的源节点特征。在这里,节点的特征可以自动展开,通过在 x 后面加入 _i 或者 _j。其实每个张量都可以如上展开(如果他们有源头和目标节点的特征)。

1
2
3
4
norm.shape
>>> torch.Size([9949])
x_j
>>> torch.Size([9949, 64])

如上代码所示:9949代表一共有9949个源节点特征,对每个源节点进行归一化之后再按照对应的邻居节点进行聚合。这里的message函数进行的就是对于边权重的处理。
上面代码所见的方法如下解释:
MessagePassing(aggr = "add", flow = "source_to_target", node_dim = -2):定义了聚合框架并使用”mean”方法,信息的流动方向为”source to target”。node_dim代表从哪个坐标轴开始传播。
MessagePassing.propagate(edge_index, size = None, **kwargs):开始进行传播,接受边索引和所有用于构建信息与更新节点嵌入的数据。propagate不仅仅可以用于更新方阵中的信息,也可以在稀疏分配矩阵(如二部图)中进行,仅需要额外传入参数size = (N,M)即可。
MessagePassing.message()flow = "source_to_target"对于节点 i ,构建类似于$\phi$的函数处理每条边 $(j, i)\in \mathcal{E}$。如果flow = "target_to_source"则处理每条边$(i,j)\in \mathcal{E}$。另外,任何传入propagete()的张量均可以采用加上 _i_j 的后缀来转化为对应节点 i 的嵌入。例如:我们选择 i 作为中心节点,j 作为邻居节点,那么 x_i 获得的就是所有边的转换后的目标节点的嵌入按照边的顺序进行排列。
MessagePassing.update(aggr_out, ...):是按照类似于$\gamma$进行更新每一个节点嵌入的方法,其接受聚合输出作为第一个参数。

Exercises Answer

  1. row stands for the source nodes of each edge, col for the target nodes
  2. degree counts how many times each node index appears in the given tensor.
  3. because Because GCN uses symmetric normalization. Here, D must be the degree of the node that aggregates messages, i.e. the target node.
  4. deg_inv_sqrt[col] gets the squared degree of the target nodes, [row] gets the squared degree of the source nodes.
  5. x_j stands for the transformed source nodes embedding for each edge. If self.lin denotes identity function, x_j is source nodes embedding for each edge.
  6. code shows below
1
2
def update(self, aggr_out, x_i):
return aggr_out + x_i

Cover image icon by Dewi Sari from Flaticon