图神经网络GNN算法基本原理详解

目录
  • 前言
  • 1. 数据
  • 2. 变量定义
  • 3. GNN算法
    • 3.1 Forward
    • 3.2 Backward
  • 4.总结与展望

前言

本文结合一个具体的无向图来对最简单的一种GNN进行推导。本文第一部分是数据介绍,第二部分为推导过程中需要用的变量的定义,第三部分是GNN的具体推导过程,最后一部分为自己对GNN的一些看法与总结。

1. 数据

利用networkx简单生成一个无向图:

# -*- coding: utf-8 -*-
"""
@Time : 2021/12/21 11:23
@Author :KI
@File :gnn_basic.py
@Motto:Hungry And Humble

"""
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

G = nx.Graph()
node_features = [[2, 3], [4, 7], [3, 7], [4, 5], [5, 5]]
edges = [(1, 2), (1, 3), (2, 4), (2, 5), (1, 3), (3, 5), (3, 4)]
edge_features = [[1, 3], [4, 1], [1, 5], [5, 3], [5, 6], [5, 4], [4, 3]]
colors = []
edge_colors = []

# add nodes
for i in range(1, len(node_features) + 1):
    G.add_node(i, feature=str(i) + ':(' + str(node_features[i-1][0]) + ',' + str(node_features[i-1][1]) + ')')
    colors.append('#DCBB8A')

# add edges
for i in range(1, len(edge_features) + 1):
    G.add_edge(edges[i-1][0], edges[i-1][1], feature='(' + str(edge_features[i-1][0]) + ',' + str(edge_features[i-1][1]) + ')')
    edge_colors.append('#3CA9C4')

# draw
fig, ax = plt.subplots()

pos = nx.spring_layout(G)
nx.draw(G, pos=pos, node_size=2000, node_color=colors, edge_color='black')
node_labels = nx.get_node_attributes(G, 'feature')
nx.draw_networkx_labels(G, pos=pos, labels=node_labels, node_size=2000, node_color=colors, font_color='r', font_size=14)
edge_labels = nx.get_edge_attributes(G, 'feature')
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=14, font_color='#7E8877')

ax.set_facecolor('deepskyblue')
ax.axis('off')
fig.set_facecolor('deepskyblue')
plt.show()

如下所示:

其中,每一个节点都有自己的一些特征,比如在社交网络中,每个节点(用户)有性别以及年龄等特征。

5个节点的特征向量依次为:

[[2, 3], [4, 7], [3, 7], [4, 5], [5, 5]]

同样,6条边的特征向量为:

[[1, 3], [4, 1], [1, 5], [5, 3], [5, 6], [5, 4], [4, 3]]

2. 变量定义

特征向量实际上也就是节点或者边的标签,这个是图本身的属性,一直保持不变。

3. GNN算法

GNN算法的完整描述如下:Forward向前计算状态,Backward向后计算梯度,主函数通过向前和向后迭代调用来最小化损失。

主函数中:

上述描述只是一个总体的概述,可以略过先不看。

3.1 Forward

早期的GNN都是RecGNN,即循环GNN。这种类型的GNN基于信息传播机制: GNN通过不断交换邻域信息来更新节点状态,直到达到稳定均衡。节点的状态向量 x 由以下 f w ​函数来进行周期性更新:

解析上述公式:对于节点 n ,假设为节点1,更新其状态需要以下数据参与:

这里的fw只是形式化的定义,不同的GNN有不同的定义,如随机稳态嵌入(SSE)中定义如下:

由更新公式可知,当所有节点的状态都趋于稳定状态时,此时所有节点的状态向量都包含了其邻居节点和相连边的信息。

这与图嵌入有些类似:如果是节点嵌入,我们最终得到的是一个节点的向量表示,而这些向量是根据随机游走序列得到的,随机游走序列中又包括了节点的邻居信息, 因此节点的向量表示中包含了连接信息。

证明上述更新过程能够收敛需要用到不动点理论,这里简单描述下:

如果我们有以下更新公式:

GNN的Foward描述如下:

解释:

3.2 Backward

在节点嵌入中,我们最终得到了每个节点的表征向量,此时我们就能利用这些向量来进行聚类、节点分类、链接预测等等。

GNN中类似,得到这些节点状态向量的最终形式不是我们的目的,我们的目的是利用这些节点状态向量来做一些实际的应用,比如节点标签预测。

因此,如果想要预测的话,我们就需要一个输出函数来对节点状态进行变换,得到我们要想要的东西:

最容易想到的就是将节点状态向量经过一个前馈神经网络得到输出,也就是说 g w g_w gw​可以是一个FNN,同样的, f w f_w fw​也可以是一个FNN:

我们利用 g w g_w gw​函数对节点 n n n收敛后的状态向量 x n x_n xn​以及其特征向量 l n l_n ln​进行变换,就能得到我们想要的输出,比如某一类别,某一具体的数值等等。

在BP算法中,我们有了输出后,就能算出损失,然后利用损失反向传播算出梯度,最后再利用梯度下降法对神经网络的参数进行更新。

对于某一节点的损失(比如回归)我们可以简单定义如下:

有了z(t)后,我们就能求导了:

z(t)的求解方法在Backward中有描述:

因此,在Backward中需要计算以下导数:

4.总结与展望

本文所讲的GNN是最原始的GNN,此时的GNN存在着不少的问题,比如对不动点隐藏状态的更新比较低效。

由于CNN在CV领域的成功,许多重新定义图形数据卷积概念的方法被提了出来,图卷积神经网络ConvGNN也被提了出来,ConvGNN被分为两大类:频域方法(spectral-based method )和空间域方法(spatial-based method)。2009年,Micheli在继承了来自RecGNN的消息传递思想的同时,在架构上复合非递归层,首次解决了图的相互依赖问题。在过去的几年里还开发了许多替代GNN,包括GAE和STGNN。这些学习框架可以建立在RecGNN、ConvGNN或其他用于图形建模的神经架构上。

GNN是用于图数据的深度学习架构,它将端到端学习与归纳推理相结合,业界普遍认为其有望解决深度学习无法处理的因果推理、可解释性等一系列瓶颈问题,是未来3到5年的重点方向。

因此,不仅仅是GNN,图领域的相关研究都是比较有前景的,这方面的应用也十分广泛,比如推荐系统、计算机视觉、物理/化学(生命科学)、药物发现等等。

以上就是图神经网络GNN算法基本原理详解的详细内容,更多关于图神经网络GNN算法的资料请关注我们其它相关文章!

(0)

相关推荐

  • 图神经网络GNN算法基本原理详解

    目录 前言 1. 数据 2. 变量定义 3. GNN算法 3.1 Forward 3.2 Backward 4.总结与展望 前言 本文结合一个具体的无向图来对最简单的一种GNN进行推导.本文第一部分是数据介绍,第二部分为推导过程中需要用的变量的定义,第三部分是GNN的具体推导过程,最后一部分为自己对GNN的一些看法与总结. 1. 数据 利用networkx简单生成一个无向图: # -*- coding: utf-8 -*- """ @Time : 2021/12/21 11:

  • java图搜索算法之图的对象化描述示例详解

    目录 一.前言 二.什么是图 三.怎么存储一个图的结构 1.邻接矩阵 2.邻接表 3.图对象化表示 四.图的作用 你好,我是小黄,一名独角兽企业的Java开发工程师. 校招收获数十个offer,年薪均20W~40W. 感谢茫茫人海中我们能够相遇, 俗话说:当你的才华和能力,不足以支撑你的梦想的时候,请静下心来学习, 希望优秀的你可以和我一起学习,一起努力,实现属于自己的梦想. 一.前言 对于图来说,我一直以来都似懂非懂 懂的是图的含义,不懂的是图具体的实现 对于当前各大厂面试的图题,不外乎以下几

  • opencv3/C++ PHash算法图像检索详解

    PHash算法即感知哈希算法/Perceptual Hash algorithm,计算基于低频的均值哈希.对每张图像生成一个指纹字符串,通过对该字符串比较可以判断图像间的相似度. PHash算法原理 将图像转为灰度图,然后将图片大小调整为32*32像素并通过DCT变换,取左上角的8*8像素区域.然后计算这64个像素的灰度值的均值.将每个像素的灰度值与均值对比,大于均值记为1,小于均值记为0,得到64位哈希值. PHash算法实现 将图片转为灰度值 将图片尺寸缩小为32*32 resize(src

  • python 图像增强算法实现详解

    使用python编写了共六种图像增强算法: 1)基于直方图均衡化 2)基于拉普拉斯算子 3)基于对数变换 4)基于伽马变换 5)限制对比度自适应直方图均衡化:CLAHE 6)retinex-SSR 7)retinex-MSR其中,6和7属于同一种下的变化. 将每种方法编写成一个函数,封装,可以直接在主函数中调用. 采用同一幅图进行效果对比. 图像增强的效果为: 直方图均衡化:对比度较低的图像适合使用直方图均衡化方法来增强图像细节 拉普拉斯算子可以增强局部的图像对比度 log对数变换对于整体对比度

  • JVM中四种GC算法案例详解

    目录 介绍 引用计数算法(Reference counting) 算法思想: 核心思想: 优点: 缺点: 例子如图: 标记–清除算法(Mark-Sweep) 算法思想: 优点 缺点 例子如图 标记–整理算法 算法思想 优点 缺点 例子 复制算法 算法思想 优点 缺点 总结 介绍 程序在运行过程中,会产生大量的内存垃圾(一些没有引用指向的内存对象都属于内存垃圾,因为这些对象已经无法访问,程序用不了它们了,对程序而言它们已经死亡),为了确保程序运行时的性能,java虚拟机在程序运行的过程中不断地进行

  • Python 经典贪心算法之Prim算法案例详解

    最小生成树的Prim算法也是贪心算法的一大经典应用.Prim算法的特点是时刻维护一棵树,算法不断加边,加的过程始终是一棵树. Prim算法过程: 一条边一条边地加, 维护一棵树. 初始 E = {}空集合, V = {任选的一个起始节点} 循环(n – 1)次,每次选择一条边(v1,v2), 满足:v1属于V , v2不属于V.且(v1,v2)权值最小. E = E + (v1,v2) V = V + v2 最终E中的边是一棵最小生成树, V包含了全部节点. 以下图为例介绍Prim算法的执行过程

  • Python深度学习pytorch神经网络图像卷积运算详解

    目录 互相关运算 卷积层 特征映射 由于卷积神经网络的设计是用于探索图像数据,本节我们将以图像为例. 互相关运算 严格来说,卷积层是个错误的叫法,因为它所表达的运算其实是互相关运算(cross-correlation),而不是卷积运算.在卷积层中,输入张量和核张量通过互相关运算产生输出张量. 首先,我们暂时忽略通道(第三维)这一情况,看看如何处理二维图像数据和隐藏表示.下图中,输入是高度为3.宽度为3的二维张量(即形状为 3 × 3 3\times3 3×3).卷积核的高度和宽度都是2. 注意,

  • 通俗易懂的C++前缀和与差分算法图文详解

    目录 1.前缀和 2.前缀和算法有什么好处? 3.二维前缀和 4.差分 5.一维差分 6.二维差分 1.前缀和 前缀和是指某序列的前n项和,可以把它理解为数学上的数列的前n项和,而差分可以看成前缀和的逆运算.合理的使用前缀和与差分,可以将某些复杂的问题简单化. 2.前缀和算法有什么好处? 先来了解这样一个问题: 输入一个长度为n的整数序列.接下来再输入m个询问,每个询问输入一对l, r.对于每个询问,输出原序列中从第l个数到第r个数的和. 我们很容易想出暴力解法,遍历区间求和. 代码如下: in

  • TensorFlow卷积神经网络AlexNet实现示例详解

    2012年,Hinton的学生Alex Krizhevsky提出了深度卷积神经网络模型AlexNet,它可以算是LeNet的一种更深更宽的版本.AlexNet以显著的优势赢得了竞争激烈的ILSVRC 2012比赛,top-5的错误率降低至了16.4%,远远领先第二名的26.2%的成绩.AlexNet的出现意义非常重大,它证明了CNN在复杂模型下的有效性,而且使用GPU使得训练在可接受的时间范围内得到结果,让CNN和GPU都大火了一把.AlexNet可以说是神经网络在低谷期后的第一次发声,确立了深

  • python机器学习Sklearn实战adaboost算法示例详解

    目录 pandas批量处理体测成绩 adaboost adaboost原理案例举例 弱分类器合并成强分类器 pandas批量处理体测成绩 import numpy as np import pandas as pd from pandas import Series,DataFrame import matplotlib.pyplot as plt data = pd.read_excel("/Users/zhucan/Desktop/18级高一体测成绩汇总.xls") cond =

随机推荐