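"""Graph Neural Network with Feature Smoothing for Link Prediction.

Node features are repeatedly propagated over the normalized training graph,
the propagated versions are mixed per node with softmax weights derived from
their cosine similarity to the raw features, and node pairs are scored with
sigmoid(X X^T). AUC and AP are reported on held-out edges.
"""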
from __future__ import division
from __future__ import print_function

# Make a mark

import argparse
import time

import numpy as np
import scipy.sparse as sp
import torch
# from torch import optim
# import torch.nn as nn
import torch.nn.functional as F
# import networkx as nx

from utils import *
# Command-line options. Each add_argument call is its own statement so that
# --seed, --dataset and --hops are actually registered on the parser.
parser = argparse.ArgumentParser()
# Random seed, passed to set_seed() before run().
parser.add_argument('--seed', type=int, default=42,
                    help='Random seed.')
# Dataset name; 'wiki' selects load_wiki(), anything else goes to load_data().
parser.add_argument('--dataset', type=str, default='cora',
                    help='type of dataset.')
# Number of feature-propagation steps used by run().
parser.add_argument('--hops', type=int, default=20,
                    help='number of hops')

# Parsed at import time from sys.argv; run() reads args.dataset / args.hops.
args = parser.parse_args()
def _hop_weighted_features(adj_norm, features, hop):
    """Propagate ``features`` for ``hop`` steps and mix the results per node.

    Builds ``[X, A X, A^2 X, ..., A^hop X]`` with ``A = adj_norm``, scores each
    term per node by the cosine similarity between the node's propagated row
    and its raw row, turns the ``hop + 1`` scores into softmax weights, and
    returns the per-node weighted sum of the terms.

    Args:
        adj_norm: normalized adjacency accepted by ``torch.spmm``; must be
            (n, n) to match ``features``. Presumably the torch sparse tensor
            returned by ``normalize_adj`` -- verify in utils.
        features: dense (n, d) feature tensor.
        hop: number of propagation steps (>= 0).

    Returns:
        Tensor of shape (n, d) with the smoothed features.
    """
    features_list = [features]
    for _ in range(hop):
        features_list.append(torch.spmm(adj_norm, features_list[-1]))

    # Row-wise cosine similarity of each propagated matrix with the raw
    # features. The 1e-10 keeps all-zero rows from dividing by zero.
    norm_fea = torch.norm(features, 2, 1).add(1e-10)
    scores = []
    for fea in features_list:
        norm_cur = torch.norm(fea, 2, 1).add(1e-10)
        cos = torch.div((features * fea).sum(1), norm_fea)
        cos = torch.div(cos, norm_cur)
        scores.append(cos.unsqueeze(-1))
    weight = F.softmax(torch.cat(scores, dim=1), dim=1)  # (n, hop + 1)

    # Row i of the result is sum_j weight[i, j] * features_list[j][i].
    # Looping over the hop + 1 terms (not over every node) keeps the work
    # vectorized without materializing an (n, hop + 1, d) stack.
    return sum(weight[:, j:j + 1] * fea for j, fea in enumerate(features_list))


def run(args):
    """Evaluate feature-smoothing link prediction on one dataset; print AUC/AP.

    Steps:
      1. Load the graph and build a self-loop-free copy used for evaluation.
      2. Hold out test edges (and non-edges) with ``mask_test_edges``; only
         the training graph is used for smoothing.
      3. For each ``r`` in a dataset-dependent list, normalize the training
         adjacency with ``normalize_adj(adj, r)`` and compute hop-weighted
         smoothed features; keep the element-wise maximum across all ``r``.
      4. Score all node pairs as ``sigmoid(X X^T)`` and report AUC / AP on the
         held-out edges via ``get_roc_score``.

    Args:
        args: parsed CLI namespace; reads ``args.dataset`` and ``args.hops``.
    """
    # 1. Load the data.
    print("Using {} dataset".format(args.dataset))
    if args.dataset == 'wiki':
        adj, features, _, _ = load_wiki()
    else:
        adj, features, _, _, _, _, _ = load_data(args.dataset)

    # Copy of the full adjacency with the diagonal (self-loops) subtracted;
    # this is the matrix handed to get_roc_score below. `adj` is not modified.
    adj_orig = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]),
                                   shape=adj.shape)
    adj_orig.eliminate_zeros()

    # 2. Hold out test edges; smoothing only sees the training graph.
    adj_train, _, _, _, test_edges, test_edges_false = mask_test_edges(adj)
    adj = adj_train

    # Values of r forwarded to normalize_adj (its meaning is defined there).
    # Pubmed sweeps a smaller set; the original code does not say why.
    if args.dataset == 'pubmed':
        r_list = [0.3, 0.4, 0.5]
    else:
        r_list = [0, 0.1, 0.2, 0.3, 0.4, 0.5]

    # 3. This range holds exactly one value (args.hops); widen it to sweep hops.
    for hop in range(args.hops, args.hops + 1):
        input_features = None
        for r in r_list:
            adj_norm = normalize_adj(adj, r)
            input_feas = _hop_weighted_features(adj_norm, features, hop)
            # Keep the element-wise maximum over all r values.
            if input_features is None:
                input_features = input_feas
            else:
                input_features = torch.max(input_features, input_feas)

        # 4. Pairwise scores and evaluation on the held-out edges.
        sim = torch.sigmoid(torch.mm(input_features, input_features.T))
        roc_score, ap_score = get_roc_score(sim.detach().cpu().numpy(), adj_orig,
                                            test_edges, test_edges_false)
        print(f'AUC: {roc_score:.4f}, AP: {ap_score:.4f}, Hop: {hop:02d}')
# Script entry point: seed the RNGs, then run the evaluation with the CLI args.
if __name__ == '__main__':
    set_seed(args.seed)  # set_seed is presumably provided by `from utils import *`
    run(args)
# Original source: https://www.cveoy.top/t/topic/nsIl
# Copyright belongs to the original author. Do not repost or scrape.