最近在学习图神经网络,进行了一个小小的demo练习,学习了一下怎么建图,怎么进行图神经网络的训练~

1、库的导入

%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch.nn import Linear
from torch_geometric.nn import GCNConv
import torch
import torch.nn as nn

这一部分需要首先安装torch和torch_geometric,不会安装的小伙伴可以看我之前的博客,需要注意一些版本之间的关系。

torch_geometric踩坑实战--安装与运行 亲测有效!!_汤汤upup的博客-CSDN博客

 2、数据集的导入及查看

torch_geometric中有一些自带数据集,这次用到的是KarateClub空手道俱乐部的数据集

from torch_geometric.datasets import KarateClub
dataset=KarateClub()

data=dataset[0]
data

输出:

Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

数据集信息如下:

总共34个节点,每个节点34个特征,156条边,节点分为4个类别,任务是对节点进行分类,判断节点属于哪个小团体。因此输出中的x=[34,34]分别代表每个节点的特征个数,以及节点个数,edge_index=[2,156],2是该属性中的不变数值,代表两个节点相连,156代表边的条数。

3、进行数据集的可视化

def visualize_graph(G,color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G,pos=nx.spring_layout(G,seed=42),with_labels=False,node_color=color,cmap="Set2")
    plt.show()

from torch_geometric.utils import to_networkx
G=to_networkx(data,to_undirected=True)
visualize_graph(G,color=data.y)

利用networkx对Data中的数据进行可视化,得到的可视化结果如下:

4、模型的训练

搭建两层的GCN模型,搭建过程非常简单,不会的同学建议看一下pytorch基础~

import torch.nn.functional as F
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = F.softmax(x, dim=1)

        return x
model = GCN(dataset.num_node_features, dataset.num_classes)
print(model)

---------------------------------------------------------------
GCN(
  (conv1): GCNConv(34, 16)
  (conv2): GCNConv(16, 4)
)

对模型进行训练

def train(model, data):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(200):
        out = model(data)
        optimizer.zero_grad()
        loss = loss_function(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))

train(model,data)

 训练结果如下:

一个简单的图神经网络学习的demo就完成啦,当然如果需要实用自己的数据集也可以,可以使用如下代码:

from torch_geometric.data import Data
x = torch.tensor([[2,1],[5,6],[3,7],[12,0]],dtype=torch.float)
y = torch.tensor([0,1,0,1],dtype=torch.float)

edge_index = torch.tensor([[0,1,2,0,3],
                           [1,0,1,3,2]],dtype=torch.long)
data = Data(x=x,y=y,edge_index=edge_index)

其中x代表节点的特征,可以看出特征的维度是2*1,y代表节点的标签,edge_index代表边的连接,是一个2*n维的矩阵,表示节点0和节点1相连,等等

建图的结果如下:

本文主要介绍的是节点分类算法,当然图分类和边分类大同小异~ 

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐