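There are 10 images, each containing 20 nodes, with 66 edges in total. Each line holds one edge, with the source node and target node separated by a space, in the format: 0 1 0 2 0 3 0 4 1 2 1 3 1 9 2 3 2 5 2 8 3 6 3 8 3 10 3 13 3 15 3 18 4 5 4 9 4 10 5 11 5 12 5 13 6 16 6 17 6 19 7 8 7 13 8 13 8 18 9 10 9 16 10 11 10 15 1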
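To illustrate the edge format described above, here is a minimal sketch that parses "source target" pairs, one per line, into a two-row edge index. The function name `parse_edges` and the three-edge sample string are hypothetical and are not part of the original code. Each undirected edge is stored in both directions, which is the layout PyG expects once the lists are converted with `torch.tensor(..., dtype=torch.long)`.

```python
def parse_edges(text):
    """Parse 'src dst' pairs (one per line) into a two-row edge index.

    Each undirected edge is stored twice, once per direction.
    """
    sources, targets = [], []
    for line in text.strip().splitlines():
        src, dst = (int(tok) for tok in line.split())
        sources += [src, dst]
        targets += [dst, src]
    return [sources, targets]


sample = "0 1\n0 2\n1 2\n"  # hypothetical three-edge excerpt
edge_index = parse_edges(sample)
print(edge_index)
```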
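The following code uses the PyG library to define a custom dataset class, My_dataset, which loads the data and splits it into training and validation sets: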
import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import to_undirected


class My_dataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Load the tensors cached by process()
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['edges.txt', 'features1.txt', 'features2.txt', 'label.txt']

    @property
    def processed_file_names(self):
        return 'processed.pt'

    def process(self):
        # Load raw data from file; node ids must be read as integers
        edges = np.loadtxt('C:/Users/jh/Desktop/data/raw/edges.txt', dtype=np.int64)
        features1 = np.loadtxt('C:/Users/jh/Desktop/data/raw/features1.txt')
        features2 = np.loadtxt('C:/Users/jh/Desktop/data/raw/features2.txt')
        labels = np.loadtxt('C:/Users/jh/Desktop/data/raw/label.txt')

        # Create node features (shape [num_nodes, 2]) and 1-D integer class labels,
        # as expected by CrossEntropyLoss
        x = np.concatenate((features1.reshape(-1, 1), features2.reshape(-1, 1)), axis=1)
        x = torch.tensor(x, dtype=torch.float)
        y = torch.tensor(labels.reshape(-1), dtype=torch.long)

        # Build the edge index directly so that node i keeps row i of x and y.
        # (from_networkx does not accept node_features/node_labels arguments.)
        # to_undirected stores each edge in both directions.
        edge_index = torch.from_numpy(edges.reshape(-1, 2)).t().contiguous()
        edge_index = to_undirected(edge_index, num_nodes=x.size(0))

        # Create PyG Data object
        data = Data(x=x, edge_index=edge_index, y=y)

        # Split data into training and validation sets:
        # the first 16 nodes train, the remaining nodes validate
        train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        train_mask[:16] = True
        data.train_mask = train_mask
        data.val_mask = ~train_mask

        data = data if self.pre_transform is None else self.pre_transform(data)

        # Save processed data
        torch.save(self.collate([data]), self.processed_paths[0])
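The split above marks the first 16 nodes as training nodes and the rest as validation nodes. As a hedged alternative (not the original author's method), a seeded random split avoids any bias that comes from node ordering. The sketch below uses plain Python lists. The helper name `random_split_masks` and the 0.8 ratio are assumptions for illustration, and the resulting lists could be converted with `torch.tensor(..., dtype=torch.bool)`.

```python
import random


def random_split_masks(num_nodes, train_ratio=0.8, seed=0):
    """Return (train_mask, val_mask) as boolean lists that never overlap."""
    indices = list(range(num_nodes))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    num_train = int(num_nodes * train_ratio)
    train_ids = set(indices[:num_train])
    train_mask = [i in train_ids for i in range(num_nodes)]
    val_mask = [not flag for flag in train_mask]
    return train_mask, val_mask


train_mask, val_mask = random_split_masks(20)
print(sum(train_mask), sum(val_mask))
```

With 20 nodes and a ratio of 0.8 this keeps the same 16/4 proportion as the fixed split, but picks the nodes at random.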
Then a GCN model can be trained and validated on the data:
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader  # DataLoader lives here since PyG 2.0
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x  # raw logits; CrossEntropyLoss applies log-softmax itself


# Load the dataset
dataset = My_dataset(root='.')

# Define the model (2 input features per node, 2 output classes)
model = GCN(num_features=2, hidden_channels=16, num_classes=2)

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Create data loader
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Training loop: the loss is computed on training nodes only
model.train()
for epoch in range(100):
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()

# Validation loop: accuracy is measured on validation nodes only
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in loader:
        out = model(batch.x, batch.edge_index)
        predicted = out[batch.val_mask].argmax(dim=1)
        total += batch.val_mask.sum().item()
        correct += (predicted == batch.y[batch.val_mask]).sum().item()

accuracy = correct / total
print(f"Validation Accuracy: {accuracy}")
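For intuition about what each GCNConv layer in the model above does: it aggregates neighbor features with a symmetrically normalized adjacency matrix, D^-1/2 (A + I) D^-1/2, and then applies its learned weights. The sketch below shows only that propagation step, in plain Python on a tiny three-node path graph. The function names `gcn_norm_adjacency` and `propagate` are hypothetical, and the learned weight matrix and bias are deliberately omitted, so this is an illustration rather than PyG's implementation.

```python
import math


def gcn_norm_adjacency(num_nodes, edges):
    """Return D^-1/2 (A + I) D^-1/2 as a dense list of lists (undirected graph)."""
    a_hat = [[0.0] * num_nodes for _ in range(num_nodes)]
    for i in range(num_nodes):
        a_hat[i][i] = 1.0  # self-loops: A + I
    for src, dst in edges:
        a_hat[src][dst] = 1.0
        a_hat[dst][src] = 1.0
    deg = [sum(row) for row in a_hat]  # degrees including the self-loop
    return [[a_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(num_nodes)]
            for i in range(num_nodes)]


def propagate(norm_adj, x):
    """One weight-free propagation step: norm_adj @ x."""
    return [[sum(norm_adj[i][k] * x[k][f] for k in range(len(x)))
             for f in range(len(x[0]))]
            for i in range(len(norm_adj))]


# Path graph 0 - 1 - 2 with one feature per node
norm_adj = gcn_norm_adjacency(3, [(0, 1), (1, 2)])
out = propagate(norm_adj, [[1.0], [1.0], [1.0]])
print(out)
```

Stacking two such steps, as the two-layer model does, lets each node see information from neighbors up to two hops away.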
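Note that the path 'C:/Users/jh/Desktop/data/raw' in the code above should be changed to match your own setup so that the data files load correctly. The model's parameters and network structure can also be adjusted as needed.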
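One way to make the path easier to change is to build all four file paths from a single directory instead of repeating it. The sketch below is an illustration only: `raw_paths_for` and the 'data/raw' directory are hypothetical names. Inside an `InMemoryDataset` subclass, the built-in `self.raw_paths` property serves the same purpose by resolving `raw_file_names` under `<root>/raw`.

```python
from pathlib import Path

RAW_FILE_NAMES = ['edges.txt', 'features1.txt', 'features2.txt', 'label.txt']


def raw_paths_for(raw_dir):
    """Return the full path of each raw file under raw_dir."""
    raw_dir = Path(raw_dir)
    return [raw_dir / name for name in RAW_FILE_NAMES]


paths = raw_paths_for('data/raw')  # hypothetical directory; adjust to your setup
print([p.name for p in paths])
```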