GCN 图神经网络使用详解 可视化 Pytorch

目录
  • 手动尝试GCN图神经网络
  • 现在让我们更详细地看一下底层图
  • 现在让我们更详细地检查edge_index的属性
  • 嵌入 Karate Club Network
  • 训练 Karate Club Network
  • 总结

手动尝试GCN图神经网络

最近,图上的深度学习已经成为深度学习社区中最热门的研究领域之一。 在这里,图神经网络(GNN)旨在将经典的深度学习概念推广到不规则的结构化数据(与图像或文本形成对比),并使神经网络能够推理出对象及其关系。

本内容介绍一些关于通过基于PyTorch几何(PyG)库的图神经网络对图进行深度学习的基本概念。

PyTorch geometry是流行的深度学习框架PyTorch的扩展库,由各种方法和实用程序组成,以简化图神经网络的实现。

在开始之前,先介绍一下配置环境:

Pytorch: 1.8.0       Cuda: 10.2    Torch-geometric

# 导入使用的模块包
import torch
import networkx as nx
import matplotlib.pyplot as plt

# 定义最后可视化的函数
def visualize(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])

    if torch.is_tensor(h):
        h = h.detach().cpu().numpy()
        plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
        if epoch is not None and loss is not None:
            plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    else:
        nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                         node_color=color, cmap="Set2")
    plt.show()

在这里,我们使用一张KarateClub图来进行讲解,这张图描述了一个由34名空手道俱乐部成员组成的社交网络,并记录了俱乐部外成员之间的联系。在这里,我们感兴趣的是检测由成员的交互产生的社区。

KarateClub图

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}') # 1
print(f'Number of features: {dataset.num_features}') # 34
print(f'Number of classes: {dataset.num_classes}') # 4

这里输出的分别是:

  • (1)图的数量、
  • (2)特征的数量
  • (3)种类

在初始化KarateClub数据集之后,我们首先可以检查它的一些属性。

例如,我们可以看到这个数据集只持有一个图,并且这个数据集中的每个节点被分配一个34维的特征向量(唯一地描述空手道俱乐部的成员)。

此外,图中正好包含4个类,它们代表每个节点所属的团体。

现在让我们更详细地看一下底层图

data = dataset[0]  # Get the first graph object.

print(data)
print('==============================================================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Data(edge_index=[2, 156], train_mask=[34], x=[34, 34], y=[34])
==============================================================
Number of nodes: 34
Number of edges: 156
Average node degree: 4.59
Number of training nodes: 4
Training node label rate: 0.12
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True

PyTorch Geometric 中的每个图形都由单个 Data 对象表示,该对象包含描述其图形表示的所有信息。

我们可以随时通过 print(data) 打印数据对象,以接收有关其属性及其形状的简短摘要:

Data(edge_index=[2, 156], x=[34, 34], y=[34], train_mask=[34])

我们可以看到该数据对象具有4个属性:

(1)edge_index:属性保存有关图连接性的信息,即每个边缘的源节点索引和目标节点索引的元组。 PyG进一步将

(2)节点特征称为x(为34个节点中的每个节点分配了一个34维特征向量),并且将

(3)节点标签称为y(每个节点被精确地分配为一个类别)。

(4)还有一个名为train_mask的附加属性,它描述了我们已经知道其社区归属的节点。 总共,我们只知道4个节点的基本标签(每个社区一个),任务是推断其余节点的社区分配。数据对象还提供一些实用程序功能来推断基础图的某些基本属性。 例如,我们可以轻松推断图中是否存在孤立的节点(即,任何节点都没有边),图是否包含自环(即(v,v)∈E)或图是否为 无向的(即,对于每个边(v,w)∈E也存在边(w,v)∈E)。

现在让我们更详细地检查edge_index的属性

from IPython.display import Javascript  # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

edge_index = data.edge_index
print(edge_index.t())
tensor([[ 0,  1],
        [ 0,  2],
        [ 0,  3],
        [ 0,  4],
        [ 0,  5],
        [ 0,  6],
        [ 0,  7],
        [ 0,  8],
         ........

这个edge_index描述了34个人的相关性。通过输出edge_index,我们可以进一步了解PyG内部是如何表示图连通性的。

我们可以看到,对于每条边,edge_index 包含两个节点索引的元组,其中第一个值描述源节点的节点索引,第二个值描述边的目标节点的节点索引。

这种表示被称为COO格式(坐标格式),通常用于表示稀疏矩阵。

PyG使用稀疏矩阵代替以密集表示形式的邻接矩阵A∈{0,1} | V |×| V | ,这是指仅保留A中的条目不为零的坐标/值。

我们可以通过将图转换为networkx库格式来进一步可视化,这种格式除了图形操作功能之外,还实现了用于可视化的强大工具

from torch_geometric.utils import to_networkx

G = to_networkx(data, to_undirected=True)
visualize(G, color=data.y)

数据库可视化

灰色、黄色、绿色、蓝色代表四类不同的俱乐部,其中每一个圆圈代表一个人,一共有34个人,每个人之间的关系就如edge_index所描述的那样。

现在,我们要通过在torch.nn.Module类继承中定义我们的网络架构来创建我们的第一个图神经网络

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # Final GNN embedding space.

        # Apply a final (linear) classifier.
        out = self.classifier(h)

        return out, h

model = GCN()
print(model)
GCN(
  (conv1): GCNConv(34, 4)
  (conv2): GCNConv(4, 4)
  (conv3): GCNConv(4, 2)
  (classifier): Linear(in_features=2, out_features=4, bias=True)
)

在这里,我们首先在 __init__ 中初始化我们所有的构建块,并定义我们forward网络的计算流程。 我们首先定义并堆叠三个图卷积层,这对应于聚合每个节点周围的 3 个邻域信息(所有节点最多 3个)。 此外,GCNConv 层将节点特征维数减少到 2 ,即 34→4→4→2 。 每个 GCNConv 层都通过 tanh 非线性增强。(可以换成RELU试一试)

之后,我们应用单个线性变换 (torch.nn.Linear) 作为分类器将我们的节点映射到 4 个类/社区中的 1 个。

我们返回最终分类器的输出以及GNN生成的最终节点嵌入。 我们继续通过 GCN() 初始化我们的最终模型,打印我们的模型会生成所有使用的子模块的摘要。

嵌入 Karate Club Network

让我们看看GNN产生的节点嵌入。这里,我们将初始节点特征x和图连通性信息edge_index传递给模型,并可视化其二维嵌入。

model = GCN()

_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')

visualize(h, color=data.y)

值得注意的是,即使在训练我们的模型的权重之前,该模型也会产生一个与图中的社区结构非常相似的节点嵌入

相同颜色(社区)的节点在嵌入空间中已经紧密地聚在一起,尽管我们的模型的权值是完全随机初始化的,而且到目前为止我们还没有进行任何训练!由此得出结论,gnn引入了很强的归纳偏置,导致输入图中彼此接近的节点产生类似的嵌入。

训练 Karate Club Network

但我们能做得更好吗? 让我们看一个示例,说明如何根据图中 4 个节点的社区分配知识(每个社区一个)来训练我们的网络参数:

由于我们模型中的所有内容都是可微分和参数化的,我们可以添加一些标签、训练模型并观察嵌入的反应。 在这里,我们使用半监督或转导学习程序:我们只是针对每个类的一个节点进行训练,但允许使用完整的输入图数据。

这个模型训练与任何其他PyTorch模型非常相似。除了定义我们的网络架构之外,我们还定义了一个损失标准(这里是CrossEntropyLoss),并初始化了一个随机梯度优化器(这里是Adam)。之后,我们执行多轮优化,每轮由前向和后向传递来计算我们的模型参数w.r.t.对前向传递的损失的梯度。

import time
from IPython.display import Javascript  # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 430})'''))

model = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    out, h = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)

可以看到,训练400轮后,它的聚类是比较明显的。正如可以看到的,我们的3层GCN模型管理线性分隔社区和正确分类大多数节点。

此外,我们只用了几行代码就完成了这一切,这要感谢PyTorch geometry库,它帮助我们完成了数据处理和GNN实现。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Python深度学习pytorch神经网络图像卷积运算详解

    目录 互相关运算 卷积层 特征映射 由于卷积神经网络的设计是用于探索图像数据,本节我们将以图像为例. 互相关运算 严格来说,卷积层是个错误的叫法,因为它所表达的运算其实是互相关运算(cross-correlation),而不是卷积运算.在卷积层中,输入张量和核张量通过互相关运算产生输出张量. 首先,我们暂时忽略通道(第三维)这一情况,看看如何处理二维图像数据和隐藏表示.下图中,输入是高度为3.宽度为3的二维张量(即形状为 3 × 3 3\times3 3×3).卷积核的高度和宽度都是2. 注意,

  • 使用pytorch提取卷积神经网络的特征图可视化

    目录 前言 1. 效果图 2. 完整代码 3. 代码说明 4. 可视化梯度,feature 总结 前言 文章中的代码是参考基于Pytorch的特征图提取编写的代码本身很简单这里只做简单的描述. 1. 效果图 先看效果图(第一张是原图,后面的都是相应的特征图,这里使用的网络是resnet50,需要注意的是下面图片显示的特征图是经过放大后的图,原图是比较小的图,因为太小不利于我们观察): 2. 完整代码 import os import torch import torchvision as tv

  • pytorch深度神经网络入门准备自己的图片数据

    目录 正文 一.所有图片放在一个文件夹内 二.不同类别的图片放在不同的文件夹内 正文 图片数据一般有两种情况: 1.所有图片放在一个文件夹内,另外有一个txt文件显示标签. 2.不同类别的图片放在不同的文件夹内,文件夹就是图片的类别. 针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理.下面分别进行说明: 一.所有图片放在一个文件夹内 这里以mnist数据集的10000个te

  • python机器学习GCN图卷积神经网络原理解析

    目录 1. 图信号处理知识 1.1 图的拉普拉斯矩阵 1.1.1 拉普拉斯矩阵的定义及示例 1.1.2 正则化拉普拉斯矩阵 1.2 图上的傅里叶变换 1.3 图信号滤波器 2. 图卷积神经网络 2.1 数学定义 2.2 GCN的理解及时间复杂度 2.3 GCN的优缺点 3. Pytorch代码解析 1. 图信号处理知识 图卷积神经网络涉及到图信号处理的相关知识,也是由图信号处理领域的知识推导发展而来,了解图信号处理的知识是理解图卷积神经网络的基础. 1.1 图的拉普拉斯矩阵 拉普拉斯矩阵是体现图

  • GCN 图神经网络使用详解 可视化 Pytorch

    目录 手动尝试GCN图神经网络 现在让我们更详细地看一下底层图 现在让我们更详细地检查edge_index的属性 嵌入 Karate Club Network 训练 Karate Club Network 总结 手动尝试GCN图神经网络 最近,图上的深度学习已经成为深度学习社区中最热门的研究领域之一. 在这里,图神经网络(GNN)旨在将经典的深度学习概念推广到不规则的结构化数据(与图像或文本形成对比),并使神经网络能够推理出对象及其关系. 本内容介绍一些关于通过基于PyTorch几何(PyG)库

  • Python编程pytorch深度卷积神经网络AlexNet详解

    目录 容量控制和预处理 读取数据集 2012年,AlexNet横空出世.它首次证明了学习到的特征可以超越手工设计的特征.它一举打破了计算机视觉研究的现状.AlexNet使用了8层卷积神经网络,并以很大的优势赢得了2012年的ImageNet图像识别挑战赛. 下图展示了从LeNet(左)到AlexNet(right)的架构. AlexNet和LeNet的设计理念非常相似,但也有如下区别: AlexNet比相对较小的LeNet5要深得多. AlexNet使用ReLU而不是sigmoid作为其激活函数

  • Python可视化Matplotlib折线图plot用法详解

    目录 1.完善原始折线图 - 给图形添加辅助功能 1.1 准备数据并画出初始折线图 1.2 添加自定义x,y刻度 1.3 中文显示问题解决 1.4 添加网格显示 1.5 添加描述信息 1.6 图像保存 2. 在一个坐标系中绘制多个图像 2.1 多次plot 2.2 显示图例 2.3 折线图的应用场景 折线图是数据分析中非常常用的图形.其中,折线图主要是以折线的上升或下降来表示统计数量的增减变化的统计图.用于分析自变量和因变量之间的趋势关系,最适合用于显示随着时间而变化的连续数据,同时还可以看出数

  • 利用Pytorch实现获取特征图的方法详解

    目录 简单加载官方预训练模型 图片预处理 提取单个特征图 提取多个特征图 简单加载官方预训练模型 torchvision.models预定义了很多公开的模型结构 如果pretrained参数设置为False,那么仅仅设定模型结构:如果设置为True,那么会启动一个下载流程,下载预训练参数 如果只想调用模型,不想训练,那么设置model.eval()和model.requires_grad_(False) 想查看模型参数可以使用modules和named_modules,其中named_modul

  • Python LeNet网络详解及pytorch实现

    目录 1.LeNet介绍 2.LetNet网络模型 3.pytorch实现LeNet 1.LeNet介绍 LeNet神经网络由深度学习三巨头之一的Yan LeCun提出,他同时也是卷积神经网络 (CNN,Convolutional Neural Networks)之父.LeNet主要用来进行手写字符的识别与分类,并在美国的银行中投入了使用.LeNet的实现确立了CNN的结构,现在神经网络中的许多内容在LeNet的网络结构中都能看到,例如卷积层,Pooling层,ReLU层.虽然LeNet早在20

  • Python LeNet网络详解及pytorch实现

    目录 1.LeNet介绍 2.LetNet网络模型 3.pytorch实现LeNet 1.LeNet介绍 LeNet神经网络由深度学习三巨头之一的Yan LeCun提出,他同时也是卷积神经网络 (CNN,Convolutional Neural Networks)之父.LeNet主要用来进行手写字符的识别与分类,并在美国的银行中投入了使用.LeNet的实现确立了CNN的结构,现在神经网络中的许多内容在LeNet的网络结构中都能看到,例如卷积层,Pooling层,ReLU层.虽然LeNet早在20

  • Python绘制惊艳的桑基图的示例详解

    目录 桑基图简介 什么是桑基图? 如何绘制桑基图? 桑基图绘图基础 调整节点位置和图表宽度 添加有意义的悬停标签 桑基图简介 很多时候,我们需要一种必须可视化数据如何在实体之间流动的情况.例如,以居民如何从一个国家迁移到另一个国家为例.这里演示了有多少居民从英格兰迁移到北爱尔兰.苏格兰和威尔士. 从这个 桑基图 (Sankey)可视化中可以明显看出,从England迁移到Wales的居民多于从Scotland或Northern Ireland迁移的居民. 什么是桑基图? 桑基图通常描绘 从一个实

  • Webpack中雪碧图插件使用详解

    背景 在开发过程中,我们需要用到很多图标,这些图标的大小不是很大,但是每次需要向服务器发送请求,从而加重服务器的负担,尤其是当网站处于高访问量的情况下或网络不稳定的时候,服务器性能会明显下降.这种情况不符合被广泛遵循的雅虎军规"尽量减少HTTP请求数"的要求(雅虎前端优化的35条军规). 为了避免这种情况,我们需要使用到雪碧图将这些图标整合到一张图片上,再使用CSS背景及其定位,将需要显示的图标移动到元素背景中. 传统方式,我们需要将图标拼接到一张图片上,计算好位置信息,这种方式维护起

  • PHP好看的版权信息注释图型实例详解

    1.神兽 <?php /** * ┏┻━━━━━┻┓ * ┃ ┃ * ┃ ┳┛ ┗┳ ┃ * ┃ ┻ ┃ * ┗━┓ ┏━━━┛ * ┃ ┃神兽 保佑 * ┃ ┃代码无BUG * ┃ ┗━━━━━━━━━┓ * ┃ 我们 jb51.net ┣┓ * ┃ ┏┛ * ┗━┓ ┏━━━┓ ┏┛ * ┗━┛ ┗━┛ */ ?> 2.佛祖 /// // _ooOoo_ // // o8888888o // // 88" . "88 // // (| ^_^ |) // // O\ =

  • Java数据结构中图的进阶详解

    目录 有向图 有向图API设计 有向图的实现 拓扑排序 拓扑排序图解 检测有向图中的环 检测有向环的API设计 检测有向环实现 代码 基于深度优先的顶点排序 顶点排序API设计 顶点排序实现 代码: 有向图 有向图的定义及相关术语 定义∶ 有向图是一副具有方向性的图,是由一组顶点和一组有方向的边组成的,每条方向的边都连着 一对有序的顶点. 出度∶ 由某个顶点指出的边的个数称为该顶点的出度. 入度: 指向某个顶点的边的个数称为该顶点的入度. 有向路径︰ 由一系列顶点组成,对于其中的每个顶点都存在一

随机推荐