图卷积神经网络入门实战

您所在的位置:网站首页 soe是什么材料 图卷积神经网络入门实战

图卷积神经网络入门实战

#图卷积神经网络入门实战| 来源: 网络整理| 查看: 265

Multi-layer Graph Convolutional Network (GCN) with first-order filters,来源:http://tkipf.github.io/graph-convolutional-networks/ 本文完整代码和数据已经上传到Github,希望大家不吝赐教,感谢! https://github.com/YoungTimes/GNN/tree/master/GCN1. GCN是什么

图卷积神经网络(Graph Convolution Networks, GCN)跟CNN一样是特征提取的工具,CNN在处理规则数据结构(如图片等)方面非常强大。

图像矩阵示意图(Euclidean Structure),图片来源【4】

但是在现实世界中,很多数据结构是不规则的,典型的就是图结构,如社交网络、知识图谱等,GNN就比较擅长处理这类数据。

社交网络拓扑示意(Non Euclidean Structure),图片来源【4】

本文主要通过一个完整的GCN对论文进行分类的例子,来展示GCN的工作过程和原理。

这里我们使用的Cora数据集,该数据集由2708篇论文的特征、分类以及它们之间引用关系的5429条边组成,这些论文的类型被划分为7个类别:Case_Based、Genetic_Algorithms、Neural_Networks、Probabilistic_Methods、Reinforcement_Learning、Rule_Learning、Theory。

最终实现的目标是,输入一篇论文的特征,就可以输出该论文属于哪个分类。

2. 数据集-Cora Dataset2.1 下载地址

https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz

2.2 数据内容

Cora Dataset是对Machine Learning Paper进行分类的数据集,它包含三个文件:

-- README: 对数据集的介绍;

-- cora.cites: 论文之间的引用关系图。文件中每行包含两个Paper ID, 第一个ID是被引用的Paper ID; 第二个是引用的Paper ID。格式如下:

-- cora.content: 包含了2708篇Paper的信息,每行的数据格式如下: + 。paper id是论文的唯一标识;word_attributes是是一个维度为1433的词向量,词向量的每个元素对应一个词,0表示该元素对应的词不在Paper中,1表示该元素对应的词在Paper中。class_label是论文的类别,每篇Paper被映射到如下7个分类之一: Case_Based、Genetic_Algorithms、Neural_Networks、Probabilistic_Methods、Reinforcement_Learning、Rule_Learning、Theory。

先看看Cora Dataset中的数据是什么样的...

import pandas as pd import numpy as np # 导入数据:分隔符为Tab raw_data_content = pd.read_csv('data/cora/cora.content',sep = '\t',header = None) # [2708 * 1435] (row, col) = raw_data_content.shape print("Cora Contents’s Row: {}, Col: {}".format(row, col)) print("=============================================") # 每行是1435维的向量,第一维是论文的ID,最后一维是论文的Label raw_data_sample = raw_data_content.head(3) features_sample =raw_data_sample.iloc[:,1:-1] labels_sample = raw_data_sample.iloc[:, -1] labels_onehot_sample = pd.get_dummies(labels_sample) print("features:{}".format(features_sample)) print("=============================================") print("labels:{}".format(labels_sample)) print("=============================================") print("labels one hot:{}".format(labels_onehot_sample)) Cora Contents’s Row: 2708, Col: 1435 ============================================= features: 1 2 3 4 5 6 7 8 9 10 ... 1424 \ 0 0 0 0 0 0 0 0 0 0 0 ... 0 1 0 0 0 0 0 0 0 0 0 0 ... 0 2 0 0 0 0 0 0 0 0 0 0 ... 0 1425 1426 1427 1428 1429 1430 1431 1432 1433 0 0 0 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 [3 rows x 1433 columns] ============================================= labels:0 Neural_Networks 1 Rule_Learning 2 Reinforcement_Learning Name: 1434, dtype: object ============================================= labels one hot: Neural_Networks Reinforcement_Learning Rule_Learning 0 1 0 0 1 0 0 1 2 0 1 0 raw_data_cites = pd.read_csv('data/cora/cora.cites',sep = '\t',header = None) # [5429 * 2] (row, col) = raw_data_cites.shape print("Cora Cites’s Row: {}, Col: {}".format(row, col)) print("=============================================") raw_data_cites_sample = raw_data_cites.head(10) print(raw_data_cites_sample) print("=============================================") # raw_data_cites.head(10).values.flatten().tolist() # Convert Cite to adj matrix idx = np.array(raw_data_content.iloc[:, 0], dtype=np.int32) idx_map = {j: i for i, j in enumerate(idx)} edge_indexs = np.array(list(map(idx_map.get, raw_data_cites.values.flatten())), dtype=np.int32) edge_indexs = edge_indexs.reshape(raw_data_cites.shape) adjacency = sp.coo_matrix((np.ones(len(edge_indexs)), (edge_indexs[:, 0], edge_indexs[:, 1])), shape=(edge_indexs.shape[0], edge_indexs.shape[0]), dtype="float32") print(adjacency) Cora Cites’s Row: 5429, Col: 2 0 1 0 35 1033 1 35 103482 2 35 103515 3 35 1050679 4 35 1103960 ... ... ... 5424 853116 19621 5425 853116 853155 5426 853118 1140289 5427 853155 853118 5428 954315 1155073 [5429 rows x 2 columns] ============================================= (163, 402) 1.0 (163, 659) 1.0 (163, 1696) 1.0 (163, 2295) 1.0 (163, 1274) 1.0 (163, 1286) 1.0 (163, 1544) 1.0 (163, 2600) 1.0 (163, 2363) 1.0 (163, 1905) 1.0 (163, 1611) 1.0 (163, 141) 1.0 (163, 1807) 1.0 (163, 1110) 1.0 (163, 174) 1.0 (163, 2521) 1.0 (163, 1792) 1.0 (163, 1675) 1.0 (163, 1334) 1.0 (163, 813) 1.0 (163, 1799) 1.0 (163, 1943) 1.0 (163, 2077) 1.0 (163, 765) 1.0 (163, 769) 1.0 : : (2228, 1093) 1.0 (2228, 1094) 1.0 (2228, 2068) 1.0 (2228, 2085) 1.0 (2694, 2331) 1.0 (617, 226) 1.0 (422, 1691) 1.0 (2142, 2096) 1.0 (1477, 1252) 1.0 (1485, 1252) 1.0 (2185, 2109) 1.0 (2117, 2639) 1.0 (1211, 1247) 1.0 (1884, 745) 1.0 (1884, 1886) 1.0 (1884, 1902) 1.0 (1885, 745) 1.0 (1885, 1884) 1.0 (1885, 1886) 1.0 (1885, 1902) 1.0 (1886, 745) 1.0 (1886, 1902) 1.0 (1887, 2258) 1.0 (1902, 1887) 1.0 (837, 1686) 1.03. 构造训练集、测试集和验证集

这里使用[0, 150)个数据作为训练集合,[150, 500)个数据作为验证集,[500, 2708)个数据作为测试集,实现上使用掩码(train_mask、val_mask、test_mask)的形式来区分训练集、验证集和测试集。

train_index = np.arange(150) val_index = np.arange(150, 500) test_index = np.arange(500, 2708) train_mask = np.zeros(edge_indexs.shape[0], dtype = np.bool) val_mask = np.zeros(edge_indexs.shape[0], dtype = np.bool) test_mask = np.zeros(edge_indexs.shape[0], dtype = np.bool) train_mask[train_index] = True val_mask[val_index] = True test_mask[test_index] = True 4. GCN核心网络模型

图(Graph)其实数据结构中最重要的概念之一,对,没错,图神经网的图(Graph)跟数据结构中的图(Graph)是一回事。假设神经网络的输入图(Graph)包含N个节点(Node),每个节点有d个特征,则所有这些节点的特征组成一个Nxd维的矩阵X;两个节点间的邻接关系组成一个NxN的邻接矩阵A(adjacency),则X和A就构成了图神经网络的输入。

4.1 核心公式:

GCN神经网络相邻两层之间传播的核心公式如下:

H^{l+1}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right) \tag{1}

其中:

\tilde{A} = A + I是图G的邻接矩阵加上自连接,I是单位矩阵;

\tilde{D}是度矩阵(Degree Matrix), 计算公式为:\tilde{D}_{i i}=\sum_{j} \tilde{A}_{i j};

H是每一层的节点特征;

W是神经网络的待训练的权重参数;

\sigma是激活函数,常用的激活函数有softmax、relu等;

我们先看下GCN网络的传播过程,再反过来看公式中详细含义。

4.2 网络传播过程

如下所示的5个Node组成的无向图的图结构。

它对应的邻接矩阵A为:

A = \left[ \begin{matrix} 0 & 1 & 0 & 0 & 1\\ 1 & 0 & 1 & 1 & 0 \\ 0 & 1 & 0 & 1 & 0 \\ 0 & 1 & 1 & 0 & 1 \\ 1 & 0 & 0 & 1 & 0 \end{matrix} \right] \\

在邻接矩阵基础上加上自环,即与单位矩阵I相加,得到:

\tilde{A} = \left[ \begin{matrix} 1 & 1 & 0 & 0 & 1\\ 1 & 1 & 1 & 1 & 0 \\ 0 & 1 & 1 & 1 & 0 \\ 0 & 1 & 1 & 1 & 1 \\ 1 & 0 & 0 & 1 & 1 \end{matrix} \right] \\

简单期间,这里假设每个Graph中每个Node的特征都是一维的,X=[1, 2, 3, 4, 5]^T,

\tilde{A} * X = \left[ \begin{matrix} 1 * 1 + 1 * 2 + 0 * 3 + 0 * 4 + 1 * 5 \\ 1 * 1 + 1 * 2 + 1 * 3 + 1 * 4 + 0 * 5 \\ 0 * 1 + 1 * 2 + 1 * 3 + 1 * 4 + 0 * 5 \\ 0 * 1 + 1 * 2 + 1 * 3 + 1 * 4 + 1 * 5 \\ 1 * 1 + 0 * 2 + 0 * 3 + 1 * 4 + 1 * 5 \\ \end{matrix} \right] \\

从上式看到了什么,对,就是各个节点都将与其一阶相邻节点的信息融合到自己的节点中,这也是公式(1)的本质所在。神经网络传播的过程,实际上就是图(Graph)中各个节点(Node)不断聚合邻居节点信息的过程。 通过两次连乘\tilde{A} * \tilde{A} * X也就实现各个Node融合自己二阶邻居节点的信息。

到这里我们应该就可以理解为什么要将邻接矩阵加上单位矩阵

因为邻接矩阵的对角线的值都是0,所以如果用邻接矩阵直接与特征矩阵相乘,就将节点自身的信息丢失了。所以为了保留节点(Node)的自身特征,需要将邻接矩阵加上单位矩阵。

既然\tilde{A} * X就可以实现图网络中的信息聚合,那么\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} 是什么意思?

这里只从直觉上说明这个问题,(事实上,我还没有来得及去详细推导公式,后续有时间补上,嘿嘿)

假设节点A的邻居只有B,为了对A进行信息聚合,最直接的方法就是平均贡献,即: New_A = 0.5* A + 0.5 * B,这样的做法看起来很合理,但是如果B的邻居非常多(极端情况是,它与图Graph上的所有其它节点都有连接),经过特征聚合之后,图(Graph)上许多的特征就会非常相似,因此需要考虑节点(Node)的度,不要让度过大的节点(Node)贡献过大。

两层的GCN网络实现如下:

from __future__ import print_function import tensorflow as tf class GraphConvolution(tf.keras.layers.Layer): """Basic graph convolution layer as in https://arxiv.org/abs/1609.02907""" def __init__(self, units, support=1, activation=None, use_bias=True kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, ): super(GraphConvolution, self).__init__() self.units = units self.use_bias = use_bias self.kernel_initializer = kernel_initializer self.bias_initializer = bias_initializer self.kernel_regularizer = kernel_regularizer self.bias_regularizer = bias_regularizer self.supports_masking = True self.support = support assert support >= 1 def build(self, input_shapes): features_shape = input_shapes[0] assert len(features_shape) == 2 input_dim = features_shape[1] self.kernel = self.add_weight(shape = (input_dim * self.support, self.units), initializer = self.kernel_initializer, name = 'kernel', regularizer = self.kernel_regularizer) if self.use_bias: self.bias = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, name='bias', regularizer = self.kernel_regularizer) else: self.bias = None self.built = True def call(self, inputs, mask=None): features = inputs[0] basis = inputs[1:] supports = list() for i in range(self.support): supports.append(K.dot(basis[i], features)) supports = K.concatenate(supports, axis=1) output = K.dot(supports, self.kernel) if self.bias: output += self.bias return self.activation(output) 5. 网络训练过程5.1 准备训练数据

数据处理的细节前面都大概提过,这里需要注意的是,在数据处理的过程中,还需要对每个Node的Feature做归一化处理。

from graph import GraphConvolutionLayer, GraphConvolutionModel from dataset import CoraData import time import tensorflow as tf import matplotlib.pyplot as plt dataset = CoraData() features, labels, adj, train_mask, val_mask, test_mask = dataset.data() graph = [features, adj] Process data ... Loading cora dataset... Dataset has 2708 nodes, 2708 edges, 1433 features.5.2 Loss计算

Loss函数中,只对训练数据(train_mask为True)进行Loss计算。

loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True) def loss(model, x, y, train_mask, training): y_ = model(x, training=training) test_mask_logits = tf.gather_nd(y_, tf.where(train_mask)) masked_labels = tf.gather_nd(y, tf.where(train_mask)) return loss_object(y_true=masked_labels, y_pred=test_mask_logits) def grad(model, inputs, targets, train_mask): with tf.GradientTape() as tape: loss_value = loss(model, inputs, targets, train_mask, training=True) return loss_value, tape.gradient(loss_value, model.trainable_variables) 5.3 实际训练过程def test(mask): logits = model(graph) test_mask_logits = tf.gather_nd(logits, tf.where(mask)) masked_labels = tf.gather_nd(labels, tf.where(mask)) ll = tf.math.equal(tf.math.argmax(masked_labels, -1), tf.math.argmax(test_mask_logits, -1)) accuarcy = tf.reduce_mean(tf.cast(ll, dtype=tf.float64)) return accuarcy model = GraphConvolutionModel() optimizer=tf.keras.optimizers.Adam(learning_rate=0.01, decay=5e-5) # 记录过程值,以便最后可视化 train_loss_results = [] train_accuracy_results = [] train_val_results = [] train_test_results = [] num_epochs = 200 for epoch in range(num_epochs): loss_value, grads = grad(model, graph, labels, train_mask) optimizer.apply_gradients(zip(grads, model.trainable_variables)) accuarcy = test(train_mask) val_acc = test(val_mask) test_acc = test(test_mask) train_loss_results.append(loss_value) train_accuracy_results.append(accuarcy) train_val_results.append(val_acc) train_test_results.append(test_acc) print("Epoch {} loss={} accuracy={} val_acc={} test_acc={}".format(epoch, loss_value, accuarcy, val_acc, test_acc)) Epoch 0 loss=1.9472886323928833 accuracy=0.4066666666666667 val_acc=0.3142857142857143 test_acc=0.2817028985507246 Epoch 1 loss=1.9314587116241455 accuracy=0.4866666666666667 val_acc=0.3742857142857143 test_acc=0.33106884057971014 Epoch 2 loss=1.9133251905441284 accuracy=0.5 val_acc=0.38571428571428573 test_acc=0.34782608695652173 Epoch 3 loss=1.8908278942108154 accuracy=0.5266666666666666 val_acc=0.3942857142857143 test_acc=0.3496376811594203 Epoch 4 loss=1.8662141561508179 accuracy=0.5533333333333333 val_acc=0.3942857142857143 test_acc=0.3423913043478261 Epoch 5 loss=1.8400791883468628 accuracy=0.56 val_acc=0.38285714285714284 test_acc=0.3401268115942029 Epoch 6 loss=1.8119205236434937 accuracy=0.5866666666666667 val_acc=0.38571428571428573 test_acc=0.33605072463768115 Epoch 7 loss=1.78205144405365 accuracy=0.6066666666666667 val_acc=0.37714285714285717 test_acc=0.33016304347826086 Epoch 8 loss=1.751450777053833 accuracy=0.6066666666666667 val_acc=0.38857142857142857 test_acc=0.33016304347826086 Epoch 9 loss=1.7200360298156738 accuracy=0.6066666666666667 val_acc=0.4057142857142857 test_acc=0.3342391304347826 Epoch 10 loss=1.6870578527450562 accuracy=0.6333333333333333 val_acc=0.42 test_acc=0.3428442028985507 Epoch 11 loss=1.6523456573486328 accuracy=0.64 val_acc=0.43142857142857144 test_acc=0.34646739130434784 Epoch 12 loss=1.616371512413025 accuracy=0.6333333333333333 val_acc=0.44 test_acc=0.35190217391304346 Epoch 13 loss=1.579743504524231 accuracy=0.64 val_acc=0.44857142857142857 test_acc=0.360054347826087 Epoch 14 loss=1.5426799058914185 accuracy=0.64 val_acc=0.4542857142857143 test_acc=0.36594202898550726 Epoch 15 loss=1.5049867630004883 accuracy=0.6466666666666666 val_acc=0.46285714285714286 test_acc=0.3686594202898551 Epoch 16 loss=1.466316819190979 accuracy=0.6666666666666666 val_acc=0.46285714285714286 test_acc=0.37273550724637683 Epoch 17 loss=1.4266818761825562 accuracy=0.6733333333333333 val_acc=0.4857142857142857 test_acc=0.37726449275362317 Epoch 18 loss=1.3862168788909912 accuracy=0.6866666666666666 val_acc=0.5 test_acc=0.3808876811594203 Epoch 19 loss=1.3451327085494995 accuracy=0.7266666666666667 val_acc=0.5114285714285715 test_acc=0.39221014492753625 Epoch 20 loss=1.3035770654678345 accuracy=0.7533333333333333 val_acc=0.5257142857142857 test_acc=0.396286231884058 Epoch 21 loss=1.2615602016448975 accuracy=0.7866666666666666 val_acc=0.5342857142857143 test_acc=0.40806159420289856 Epoch 22 loss=1.2191429138183594 accuracy=0.8 val_acc=0.5457142857142857 test_acc=0.40851449275362317 Epoch 23 loss=1.1763759851455688 accuracy=0.82 val_acc=0.56 test_acc=0.41893115942028986 Epoch 24 loss=1.133314609527588 accuracy=0.82 val_acc=0.5742857142857143 test_acc=0.4316123188405797 Epoch 25 loss=1.09010648727417 accuracy=0.8666666666666667 val_acc=0.5771428571428572 test_acc=0.4429347826086957 Epoch 26 loss=1.0468487739562988 accuracy=0.8666666666666667 val_acc=0.5942857142857143 test_acc=0.452445652173913 Epoch 27 loss=1.0036686658859253 accuracy=0.9 val_acc=0.6028571428571429 test_acc=0.46195652173913043 Epoch 28 loss=0.9607122540473938 accuracy=0.9466666666666667 val_acc=0.6171428571428571 test_acc=0.47101449275362317 Epoch 29 loss=0.9181469678878784 accuracy=0.9666666666666667 val_acc=0.6257142857142857 test_acc=0.47690217391304346 Epoch 30 loss=0.8761056661605835 accuracy=0.9666666666666667 val_acc=0.6371428571428571 test_acc=0.48777173913043476 Epoch 31 loss=0.8347143530845642 accuracy=0.9666666666666667 val_acc=0.6485714285714286 test_acc=0.4986413043478261 Epoch 32 loss=0.79410320520401 accuracy=0.9666666666666667 val_acc=0.6514285714285715 test_acc=0.5140398550724637 Epoch 33 loss=0.7544015645980835 accuracy=0.9666666666666667 val_acc=0.6514285714285715 test_acc=0.5253623188405797 Epoch 34 loss=0.7157045602798462 accuracy=0.9666666666666667 val_acc=0.6628571428571428 test_acc=0.5335144927536232 Epoch 35 loss=0.6780853271484375 accuracy=0.9666666666666667 val_acc=0.6657142857142857 test_acc=0.5443840579710145 Epoch 36 loss=0.6416434049606323 accuracy=0.98 val_acc=0.6771428571428572 test_acc=0.5498188405797102 Epoch 37 loss=0.6064579486846924 accuracy=0.9866666666666667 val_acc=0.6685714285714286 test_acc=0.5588768115942029 Epoch 38 loss=0.5725471377372742 accuracy=1.0 val_acc=0.6771428571428572 test_acc=0.5706521739130435 Epoch 39 loss=0.5399670600891113 accuracy=1.0 val_acc=0.6771428571428572 test_acc=0.5765398550724637 Epoch 40 loss=0.5087469220161438 accuracy=1.0 val_acc=0.68 test_acc=0.5806159420289855 Epoch 41 loss=0.4788845181465149 accuracy=1.0 val_acc=0.68 test_acc=0.5869565217391305 Epoch 42 loss=0.4503992795944214 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.5892210144927537 Epoch 43 loss=0.4233132302761078 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5932971014492754 Epoch 44 loss=0.3976200222969055 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5960144927536232 Epoch 45 loss=0.3733169138431549 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5973731884057971 Epoch 46 loss=0.3503628969192505 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6000905797101449 Epoch 47 loss=0.32872718572616577 accuracy=1.0 val_acc=0.6942857142857143 test_acc=0.6014492753623188 Epoch 48 loss=0.3083726167678833 accuracy=1.0 val_acc=0.7 test_acc=0.6023550724637681 Epoch 49 loss=0.28924670815467834 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6032608695652174 Epoch 50 loss=0.27130240201950073 accuracy=1.0 val_acc=0.7057142857142857 test_acc=0.6023550724637681 Epoch 51 loss=0.2544911503791809 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6032608695652174 Epoch 52 loss=0.23875992000102997 accuracy=1.0 val_acc=0.7057142857142857 test_acc=0.6041666666666666 Epoch 53 loss=0.22406704723834991 accuracy=1.0 val_acc=0.7085714285714285 test_acc=0.6059782608695652 Epoch 54 loss=0.21036122739315033 accuracy=1.0 val_acc=0.7085714285714285 test_acc=0.6077898550724637 Epoch 55 loss=0.1975751519203186 accuracy=1.0 val_acc=0.7057142857142857 test_acc=0.6073369565217391 Epoch 56 loss=0.18565499782562256 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6064311594202898 Epoch 57 loss=0.17455774545669556 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6073369565217391 Epoch 58 loss=0.16423100233078003 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6082427536231884 Epoch 59 loss=0.15462662279605865 accuracy=1.0 val_acc=0.7 test_acc=0.6096014492753623 Epoch 60 loss=0.14569756388664246 accuracy=1.0 val_acc=0.7 test_acc=0.6109601449275363 Epoch 61 loss=0.13739608228206635 accuracy=1.0 val_acc=0.7 test_acc=0.6109601449275363 Epoch 62 loss=0.12968382239341736 accuracy=1.0 val_acc=0.6942857142857143 test_acc=0.6127717391304348 Epoch 63 loss=0.12251587212085724 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6127717391304348 Epoch 64 loss=0.11585451662540436 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6118659420289855 Epoch 65 loss=0.10966223478317261 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6114130434782609 Epoch 66 loss=0.10390333086252213 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6114130434782609 Epoch 67 loss=0.09854617714881897 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6109601449275363 Epoch 68 loss=0.09356522560119629 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6105072463768116 Epoch 69 loss=0.08893042057752609 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6100543478260869 Epoch 70 loss=0.08461488038301468 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6100543478260869 Epoch 71 loss=0.08059524744749069 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6105072463768116 Epoch 72 loss=0.07684874534606934 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6100543478260869 Epoch 73 loss=0.0733552798628807 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6096014492753623 Epoch 74 loss=0.07009640336036682 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6100543478260869 Epoch 75 loss=0.06705603748559952 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6096014492753623 Epoch 76 loss=0.06421645730733871 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6091485507246377 Epoch 77 loss=0.06155867129564285 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6086956521739131 Epoch 78 loss=0.05906983092427254 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6086956521739131 Epoch 79 loss=0.05673719570040703 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6086956521739131 Epoch 80 loss=0.054548606276512146 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6091485507246377 Epoch 81 loss=0.052494876086711884 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6091485507246377 Epoch 82 loss=0.050564948469400406 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6082427536231884 Epoch 83 loss=0.04874930530786514 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6082427536231884 Epoch 84 loss=0.04704001545906067 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6086956521739131 Epoch 85 loss=0.04542906954884529 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6082427536231884 Epoch 86 loss=0.04390912503004074 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6077898550724637 Epoch 87 loss=0.04247550293803215 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.6082427536231884 Epoch 88 loss=0.04112052172422409 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.6073369565217391 Epoch 89 loss=0.03983796760439873 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.6073369565217391 Epoch 90 loss=0.03862294182181358 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.6073369565217391 Epoch 91 loss=0.03747102618217468 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6064311594202898 Epoch 92 loss=0.03637789562344551 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6059782608695652 Epoch 93 loss=0.035339970141649246 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.605072463768116 Epoch 94 loss=0.03435278683900833 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.605072463768116 Epoch 95 loss=0.03341297432780266 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6041666666666666 Epoch 96 loss=0.03251757472753525 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.603713768115942 Epoch 97 loss=0.03166373074054718 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6032608695652174 Epoch 98 loss=0.03084862045943737 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6041666666666666 Epoch 99 loss=0.030070148408412933 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6046195652173914 Epoch 100 loss=0.029325664043426514 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6041666666666666 Epoch 101 loss=0.028613392263650894 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6046195652173914 Epoch 102 loss=0.02793121710419655 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666 Epoch 103 loss=0.027277110144495964 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666 Epoch 104 loss=0.02665024995803833 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666 Epoch 105 loss=0.026048408821225166 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942 Epoch 106 loss=0.02547014318406582 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942 Epoch 107 loss=0.024914324283599854 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942 Epoch 108 loss=0.02437940239906311 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942 Epoch 109 loss=0.023864200338721275 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666 Epoch 110 loss=0.023367829620838165 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942 Epoch 111 loss=0.02288944460451603 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942 Epoch 112 loss=0.022427884861826897 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942 Epoch 113 loss=0.021982161328196526 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666 Epoch 114 loss=0.021551571786403656 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666 Epoch 115 loss=0.0211354810744524 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6032608695652174 Epoch 116 loss=0.020733091980218887 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6028079710144928 Epoch 117 loss=0.020343618467450142 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6019021739130435 Epoch 118 loss=0.019966619089245796 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6014492753623188 Epoch 119 loss=0.019601713865995407 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6014492753623188 Epoch 120 loss=0.01924823224544525 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942 Epoch 121 loss=0.01890559121966362 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942 Epoch 122 loss=0.018573053181171417 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942 Epoch 123 loss=0.018250416964292526 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 124 loss=0.017937207594513893 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6005434782608695 Epoch 125 loss=0.017632879316806793 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6005434782608695 Epoch 126 loss=0.017337223514914513 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6000905797101449 Epoch 127 loss=0.017049791291356087 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203 Epoch 128 loss=0.01677037589251995 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203 Epoch 129 loss=0.016498537734150887 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203 Epoch 130 loss=0.016234181821346283 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203 Epoch 131 loss=0.01597682572901249 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203 Epoch 132 loss=0.015726083889603615 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6000905797101449 Epoch 133 loss=0.015481753274798393 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 134 loss=0.01524385903030634 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 135 loss=0.015011751092970371 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942 Epoch 136 loss=0.014785613864660263 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 137 loss=0.014565042220056057 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 138 loss=0.014349889941513538 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 139 loss=0.014139854349195957 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 140 loss=0.013934796676039696 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 141 loss=0.013734581880271435 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 142 loss=0.013539088889956474 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 143 loss=0.013348042033612728 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 144 loss=0.01316142175346613 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203 Epoch 145 loss=0.012978975661098957 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 146 loss=0.01280051190406084 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6000905797101449 Epoch 147 loss=0.012626114301383495 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6000905797101449 Epoch 148 loss=0.012455460615456104 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 149 loss=0.012288510799407959 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 150 loss=0.012125165201723576 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 151 loss=0.011965337209403515 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 152 loss=0.011808915995061398 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 153 loss=0.011655798181891441 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 154 loss=0.011505785398185253 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 155 loss=0.011358906514942646 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 156 loss=0.011214920319616795 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 157 loss=0.011073877103626728 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 158 loss=0.010935710743069649 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 159 loss=0.010800261981785297 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 160 loss=0.010667545720934868 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 161 loss=0.01053738035261631 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 162 loss=0.010409791953861713 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 163 loss=0.010284650139510632 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 164 loss=0.010161920450627804 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 165 loss=0.010041462257504463 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 166 loss=0.009923247620463371 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 167 loss=0.009807263500988483 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 168 loss=0.009693419560790062 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 169 loss=0.009581669233739376 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 170 loss=0.009471924044191837 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695 Epoch 171 loss=0.009364166297018528 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 172 loss=0.00925836805254221 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 173 loss=0.009154461324214935 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 174 loss=0.009052390232682228 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449 Epoch 175 loss=0.008952111005783081 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203 Epoch 176 loss=0.008853581734001637 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203 Epoch 177 loss=0.008756755851209164 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203 Epoch 178 loss=0.008661641739308834 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203 Epoch 179 loss=0.008568093180656433 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5991847826086957 Epoch 180 loss=0.008476126939058304 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5991847826086957 Epoch 181 loss=0.008385734632611275 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5991847826086957 Epoch 182 loss=0.008296912536025047 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5991847826086957 Epoch 183 loss=0.008209548890590668 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 184 loss=0.008123602718114853 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 185 loss=0.008039114996790886 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 186 loss=0.007956001907587051 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 187 loss=0.00787423737347126 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 188 loss=0.0077937874011695385 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 189 loss=0.00771462032571435 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 190 loss=0.007636724505573511 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 191 loss=0.0075600543059408665 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 192 loss=0.007484584581106901 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 193 loss=0.007410289254039526 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 194 loss=0.0073371464386582375 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 195 loss=0.007265167310833931 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 196 loss=0.007194266188889742 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 197 loss=0.007124484982341528 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 198 loss=0.007055748254060745 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203 Epoch 199 loss=0.006988039705902338 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203

可以看到,经过200次迭代后,最终GCN网络在验证集上的准确率达到70%左右,在测试集中的Accuracy达到了60%左右。

5.4 训练过程可视化# 训练过程可视化 fig, axes = plt.subplots(4, sharex=True, figsize=(12, 8)) fig.suptitle('Training Metrics') axes[0].set_ylabel("Loss", fontsize=14) axes[0].plot(train_loss_results) axes[1].set_ylabel("Accuracy", fontsize=14) axes[1].plot(train_accuracy_results) axes[2].set_ylabel("Val Acc", fontsize=14) axes[2].plot(train_val_results) axes[3].set_ylabel("Test Acc", fontsize=14) axes[3].plot(train_test_results) plt.show() ======================================================2020.11.1更新:

感谢@二三同学指正,为edges增加了对称关系,val_acc提升到83.7%,test_acc提升到76.6%,代码已更新到github。

Epoch 188 loss=0.034059930592775345 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7667572498321533 Epoch 189 loss=0.03361804038286209 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7667572498321533 Epoch 190 loss=0.0331842303276062 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7667572498321533 Epoch 191 loss=0.032758601009845734 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7667572498321533 Epoch 192 loss=0.032341040670871735 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7667572498321533 Epoch 193 loss=0.031931228935718536 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7676630616188049 Epoch 194 loss=0.03152924403548241 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7676630616188049 Epoch 195 loss=0.031134601682424545 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7676630616188049 Epoch 196 loss=0.030747264623641968 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7672101259231567 Epoch 197 loss=0.030366968363523483 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7672101259231567 Epoch 198 loss=0.02999333292245865 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7667572498321533 Epoch 199 loss=0.029626576229929924 accuracy=1.0 val_acc=0.8371428847312927 test_acc=0.7663043737411499参考材料

1.https://github.com/FighterLYL/GraphNeuralNetwork

2.https://github.com/tkipf/keras-gcn

3.https://blog.csdn.net/qq_41995574/article/details/99712339

4.https://blog.csdn.net/weixin_40013463/article/details/81089223

推荐阅读

注:本文首发于微信公众号【半杯茶的小酒杯】,转载请注明出处,谢谢!



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3