图神经网络中的边的预测问题

您所在的位置:网站首页 基于图神经网络的链路预测 图神经网络中的边的预测问题

图神经网络中的边的预测问题

2024-07-12 18:59| 来源: 网络整理| 查看: 265

在探讨图神经网络(GNNs)在边的预测中的应用之前,首先需明确边的权重这一概念。边的权重在图论中指代节点间关系的强度或可能性,这在众多实际应用场景中,如社交网络、交通网络或分子结构网络中,具有重要意义。边的权重可能代表朋友间的亲密度、交通流量或化学键的强度等。在GNNs中,边的权重通常是通过学习得到的,反映了网络对图中节点间关系的认知和理解。 在这里插入图片描述

1. GNNs中边的预测概述

图神经网络(GNNs)在边的预测上的应用核心在于其能够有效捕捉图中节点间复杂的相互作用和依赖关系。边的预测问题可视为一种链接预测任务,目的是判断图中任意两个节点间是否可能存在边,以及这些边的潜在权重。这种预测不仅限于判断边的存在与否,还涉及到预测边的性质,如权重、类型等。GNN通过聚合邻域信息来更新节点的表示,从而能够学习到能反映整个图结构的复杂模式的节点嵌入。这些嵌入随后可用于预测节点间未知的关系,即边的预测。边的预测在多种应用中至关重要,如在推荐系统中预测用户与商品之间的互动,在社交网络中预测潜在的社交连接,在生物网络中预测蛋白质间的相互作用等。

2. 边的权重预测方法

边的权重预测通常依赖于自编码器结构的GNN,包括编码器和解码器两部分。编码器负责将节点的特征信息映射为低维向量表示,而解码器则基于这些向量表示重建图的边信息,包括边的存在与否及其权重。通过这一过程,GNN学习到的节点表示包含了丰富的结构和属性信息,使得模型能够预测未见过的边信息。

2.1 编码器

编码器的作用是将图中的节点编码为低维向量表示。这一过程通常通过图卷积网络(GCN)、图注意力网络(GAT)等GNN变体实现,这些变体能够利用节点的特征信息及其图结构信息。通过堆叠多层GNN层,编码器能够捕获节点的高阶邻域信息,从而生成能够反映复杂图结构和节点属性的综合性嵌入。

2.2 解码器

解码器的任务是基于节点的向量表示预测节点对之间的边信息,包括边的存在与否以及其权重。这一过程可以通过简单的点乘操作来实现,即计算两个节点向量的点乘来得到边的预测权重。解码器也可以采用更复杂的机制,如基于学习的相似性度量函数,来预测边的信息。

3. 训练与优化

在GNN的训练过程中,目标是通过优化模型参数来最小化预测误差,这通常涉及到对节点表示的学习和边预测误差的最小化。训练过程中常见的挑战包括处理数据中的不平衡性,特别是在边的预测任务中,因为实际图数据中存在边的节点对远少于不存在边的节点对。为了解决这一问题,负采样策略被广泛应用,即随机选择不存在边的节点对作为负样本参与训练,以平衡正负样本比例。此外,过拟合是另一个常见的挑战,特别是当图数据规模相对较小或者边信息稀疏时。为了缓解过拟合,可以采用正则化技术、dropout等方法。

import os.path as osp # 导入os.path模块,并重命名为osp,用于处理文件路径 import torch # 导入PyTorch库 from sklearn.metrics import roc_auc_score # 从scikit-learn库导入roc_auc_score函数,用于评估模型 import torch_geometric.transforms as T # 导入PyTorch Geometric的transforms模块,用于图数据的变换 from torch_geometric.datasets import Planetoid # 导入PyTorch Geometric的Planetoid数据集 from torch_geometric.nn import GCNConv # 导入PyTorch Geometric的GCNConv图卷积层 from torch_geometric.utils import negative_sampling # 导入负采样工具函数 # 检测可用的设备,优先使用CUDA,其次是MPS(苹果的Metal Performance Shaders),最后是CPU if torch.cuda.is_available(): device = torch.device('cuda') elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') # 定义数据变换,包括特征归一化、设备转移和随机链接分割 transform = T.Compose([ T.NormalizeFeatures(), # 特征归一化 T.ToDevice(device), # 转移到指定设备 T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, add_negative_train_samples=False), # 随机链接分割 ]) # 设置数据集路径并加载Planetoid数据集,这里以Cora数据集为例 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, name='Cora', transform=transform) # 数据集被分割为训练集、验证集和测试集 train_data, val_data, test_data = dataset[0] # 定义图卷积网络模型 class Net(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.conv1 = GCNConv(in_channels, hidden_channels) # 第一层图卷积 self.conv2 = GCNConv(hidden_channels, out_channels) # 第二层图卷积 def encode(self, x, edge_index): # 编码函数,用于节点特征的转换 x = self.conv1(x, edge_index).relu() # 第一层卷积后使用ReLU激活函数 return self.conv2(x, edge_index) # 第二层卷积 def decode(self, z, edge_label_index): # 解码函数,用于预测边是否存在 return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1) def decode_all(self, z): # 解码所有节点对的函数 prob_adj = z @ z.t() # 计算节点特征的内积作为边的预测概率 return (prob_adj > 0).nonzero(as_tuple=False).t() # 返回概率大于0的边 # 初始化模型、优化器和损失函数 model = Net(dataset.num_features, 128, 64).to(device) optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01) criterion = torch.nn.BCEWithLogitsLoss() # 定义训练函数 def train(): model.train() # 设置模型为训练模式 optimizer.zero_grad() # 清空梯度 z = model.encode(train_data.x, train_data.edge_index) # 对训练数据进行编码 # 对每个训练周期进行一次新的负采样 neg_edge_index = negative_sampling( edge_index=train_data.edge_index, num_nodes=train_data.num_nodes, num_neg_samples=train_data.edge_label_index.size(1), method='sparse') # 合并正样本和负样本的边索引 edge_label_index = torch.cat( [train_data.edge_label_index, neg_edge_index], dim=-1, ) # 创建边的标签,正样本为1,负样本为0 edge_label = torch.cat([ train_data.edge_label, train_data.edge_label.new_zeros(neg_edge_index.size(1)) ], dim=0) out = model.decode(z, edge_label_index).view(-1) # 解码边的存在概率 loss = criterion(out, edge_label) # 计算损失 loss.backward() # 反向传播 optimizer.step() # 更新模型参数 return loss # 定义测试函数,用于评估模型在验证集和测试集上的性能 @torch.no_grad() # 不计算梯度,以节省计算资源 def test(data): model.eval() # 设置模型为评估模式 z = model.encode(data.x, data.edge_index) # 对数据进行编码 out = model.decode(z, data.edge_label_index).view(-1).sigmoid() # 解码并应用Sigmoid函数得到边存在的概率 return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()) # 计算并返回ROC AUC分数 # 训练和评估模型 best_val_auc = final_test_auc = 0 for epoch in range(1, 101): # 进行100个训练周期 loss = train() # 训练模型 val_auc = test(val_data) # 在验证集上评估模型 test_auc = test(test_data) # 在测试集上评估模型 if val_auc > best_val_auc: # 更新最佳验证集AUC和对应的测试集AUC best_val_auc = val_auc final_test_auc = test_auc # 打印训练信息 print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, Test: {test_auc:.4f}') # 打印最终测试集上的AUC分数 print(f'Final Test: {final_test_auc:.4f}') # 对测试集数据进行编码并解码所有节点对,用于预测图中所有可能的边 z = model.encode(test_data.x, test_data.edge_index) final_edge_index = model.decode_all(z) 4. 应用场景

GNNs在边的预测中的应用极为广泛,涵盖了社交网络中的好友推荐、生物信息学中的蛋白质相互作用预测、知识图谱的关系预测等领域。通过精确预测边的存在和权重,GNNs有助于揭示隐藏在复杂网络结构中的深层次模式和关系。

4.1 社交网络分析

在社交网络中,GNNs可以预测用户之间可能形成的新的社交连接,或者预测现有连接的强度。这对于好友推荐、社交圈分析等应用至关重要。

4.2 蛋白质相互作用预测

在生物信息学中,GNNs可以预测不同蛋白质之间的相互作用,这对于理解蛋白质的功能和疾病机理研究具有重要意义。

4.3 推荐系统

GNNs能够预测用户和商品之间的潜在连接,即用户可能对哪些商品感兴趣,从而提供个性化的推荐。

4.4 知识图谱补全

在知识图谱中,GNNs可以预测实体间缺失的关系,即链接预测,有助于知识图谱的自动补全和扩展。

参考文献

pytorch_geometric



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3