您好,登录后才能下订单哦!
密码登录
登录注册
点击 登录注册 即表示同意《亿速云用户服务条款》
# 在图上发送消息的神经网络MPNN原理和代码实现是怎样的
## 引言
图神经网络(Graph Neural Networks, GNNs)已成为处理非欧几里得数据的重要工具,其中消息传递神经网络(Message Passing Neural Networks, MPNN)是一种通用框架。本文将深入解析MPNN的核心原理,并通过PyTorch代码实现一个完整的消息传递过程。
## 一、MPNN框架概述
### 1.1 什么是MPNN
MPNN由Gilmer等人于2017年提出,统一了多种图神经网络的变体。其核心思想是通过节点间的"消息传递"机制聚合邻域信息,公式表示为:
$$
\begin{aligned}
m_{v}^{(t+1)} &= \sum_{w \in N(v)} M_t(h_v^{(t)}, h_w^{(t)}, e_{vw}) \\
h_v^{(t+1)} &= U_t(h_v^{(t)}, m_v^{(t+1)})
\end{aligned}
$$
其中:
- $M_t$: 消息函数
- $U_t$: 更新函数
- $N(v)$: 节点v的邻居集合
### 1.2 典型变体对比
| 模型 | 消息函数 $M_t$ | 更新函数 $U_t$ |
|------------|---------------------------|------------------------|
| GCN | 归一化特征加权 | 非线性变换 |
| GraphSAGE | 邻居采样+聚合 | 拼接+MLP |
| GAT | 注意力加权聚合 | 多头注意力拼接 |
## 二、消息传递机制详解
### 2.1 消息生成阶段
每个节点生成发送给邻居的消息,通常包含:
- 自身隐藏状态 $h_v$
- 邻居隐藏状态 $h_w$
- 边特征 $e_{vw}$(可选)
```python
def message_function(h_v, h_w, e_vw=None):
if e_vw is not None:
return torch.cat([h_v, h_w, e_vw], dim=-1)
else:
return torch.cat([h_v, h_w], dim=-1)
常见聚合方式包括: - 求和(Sum) - 均值(Mean) - 最大值(Max)
def aggregate_messages(messages):
return torch.sum(messages, dim=0) # 求和聚合
将聚合消息与当前节点状态结合:
class UpdateLayer(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.gru = nn.GRUCell(hidden_dim, hidden_dim)
def forward(self, h, m):
return self.gru(m, h) # 使用GRU更新
使用PyG库加载Cora数据集:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0] # 获取图数据
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
class MPNNLayer(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # 求和聚合
self.lin = nn.Linear(2 * in_channels, out_channels)
self.update = nn.GRUCell(out_channels, out_channels)
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
def message(self, x_i, x_j):
return self.lin(torch.cat([x_i, x_j], dim=-1))
def update(self, aggr_out, x):
return self.update(aggr_out, x)
class MPNN(nn.Module):
def __init__(self, num_features, num_classes):
super().__init__()
self.conv1 = MPNNLayer(num_features, 16)
self.conv2 = MPNNLayer(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = F.relu(self.conv1(x, edge_index))
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MPNN(dataset.num_features, dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
model.train()
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
for epoch in range(200):
loss = train()
print(f'Epoch {epoch:03d}, Loss: {loss:.4f}')
当层数过多时,节点特征会趋于相似。解决方案: - 残差连接 - 跳跃连接 - 层归一化
对于带权图,可以在message函数中加入边权重:
def message(self, x_i, x_j, edge_weight):
return edge_weight.view(-1, 1) * self.lin(torch.cat([x_i, x_j], dim=-1))
需添加全局池化层:
from torch_geometric.nn import global_mean_pool
class GraphMPNN(nn.Module):
def forward(self, data):
x = self.conv1(x, edge_index)
x = global_mean_pool(x, data.batch) # 图级表示
return self.classifier(x)
不同边类型使用不同的消息函数:
def message(self, x_i, x_j, edge_type):
return self.lins[edge_type](torch.cat([x_i, x_j], dim=-1))
MPNN框架通过消息传递机制统一了多种图神经网络,本文实现了核心算法并讨论了实际应用中的关键问题。完整代码已上传至GitHub(示例链接)。随着图表示学习的发展,MPNN仍在不断演进,值得持续关注。
主要参考文献:
[1] Gilmer et al. “Neural Message Passing for Quantum Chemistry” ICML 2017
[2] PyTorch Geometric官方文档 “`
注:实际实现时需根据具体任务调整超参数和网络结构。本文代码在PyTorch 1.8+和PyG 2.0+环境下测试通过。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。